From 97ae1254f33130b8651fcb6a7d349a76cb281e81 Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Sat, 25 May 2024 21:42:21 +0300 Subject: [PATCH 1/2] feat: transformation conditions implementation * added test for transformation pipelines. Revised some code * Revised async pipeline * users can access simply to the record using record.column_name to get the driver encoded value or raw_record.column_name to get a raw string record * implemented ast debugging for transformation cond * implemented expression patcher * Added doc * Revised some tests * users can access simply to the record using record.column_name to get the driver encoded value or raw_record.column_name to get a raw string record --- .gitignore | 1 - .../custom_functions/index.md | 12 +- .../advanced_transformers/index.md | 4 +- .../transformation_condition.md | 160 ++++++++++++ go.mod | 1 + go.sum | 2 + internal/db/postgres/cmd/validate.go | 36 +-- .../cmd/validate_utils/json_document_test.go | 40 +-- internal/db/postgres/context/table.go | 4 +- internal/db/postgres/context/transformers.go | 13 +- .../dumpers/transformation_pipeline.go | 62 +++-- .../dumpers/transformation_pipeline_test.go | 108 ++++++++ .../postgres/dumpers/transformation_window.go | 67 +++-- .../dumpers/transformation_window_test.go | 158 ++++++++++++ internal/db/postgres/entries/table.go | 1 + .../postgres/transformers/column_context.go | 4 +- .../db/postgres/transformers/dict_test.go | 6 + internal/db/postgres/transformers/email.go | 10 +- .../db/postgres/transformers/hash_test.go | 3 + .../db/postgres/transformers/json_context.go | 2 +- .../db/postgres/transformers/json_test.go | 3 + .../db/postgres/transformers/masking_test.go | 2 + .../postgres/transformers/noise_date_test.go | 1 + .../postgres/transformers/noise_float_test.go | 1 + .../postgres/transformers/noise_int_test.go | 1 + .../transformers/noise_numeric_test.go | 1 + .../postgres/transformers/random_bool_test.go | 1 + .../transformers/random_choice_test.go | 3 + .../postgres/transformers/random_date_test.go | 1 + .../transformers/random_float_test.go | 1 + .../postgres/transformers/random_int_test.go | 3 + .../postgres/transformers/random_ip_test.go | 1 + .../postgres/transformers/random_mac_test.go | 1 + .../transformers/random_numeric_test.go | 3 + .../transformers/random_person_test.go | 9 +- .../transformers/random_string_test.go | 1 + .../random_unix_timestamp_test.go | 3 + .../postgres/transformers/random_uuid_test.go | 1 + .../transformers/real_address_test.go | 3 + .../transformers/regexp_replace_test.go | 1 + .../db/postgres/transformers/replace_test.go | 3 + .../transformers/ro_record_context.go | 47 ---- .../db/postgres/transformers/set_null_test.go | 1 + .../transformers/template_record_test.go | 4 +- .../db/postgres/transformers/template_test.go | 2 + .../postgres/transformers/utils/definition.go | 28 +- .../transformers/utils/definition_test.go | 2 +- internal/domains/config.go | 2 + mkdocs.yml | 1 + pkg/toolkit/dynamic_parameter.go | 4 +- pkg/toolkit/expr.go | 243 ++++++++++++++++++ pkg/toolkit/expt_test.go | 67 +++++ pkg/toolkit/record_test.go | 73 +++--- pkg/toolkit/template_record_context.go | 4 +- .../{testutils_test.go => testutils.go} | 4 + pkg/toolkit/testutils/testutils.go | 68 +++++ pkg/toolkit/validation_warning.go | 42 +-- .../external_transformer}/test.go | 1 + 58 files changed, 1108 insertions(+), 223 deletions(-) create mode 100644 docs/built_in_transformers/transformation_condition.md create mode 100644 internal/db/postgres/dumpers/transformation_pipeline_test.go create mode 100644 internal/db/postgres/dumpers/transformation_window_test.go delete mode 100644 internal/db/postgres/transformers/ro_record_context.go create mode 100644 pkg/toolkit/expr.go create mode 100644 pkg/toolkit/expt_test.go rename pkg/toolkit/{testutils_test.go => testutils.go} (94%) create mode 100644 pkg/toolkit/testutils/testutils.go rename {pkg/toolkit/test => tests/external_transformer}/test.go (99%) diff --git a/.gitignore b/.gitignore index 5e82a3dd..0c9302eb 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,3 @@ venv .cache # Binaries cmd/greenmask/greenmask -pkg/toolkit/test/test diff --git a/docs/built_in_transformers/advanced_transformers/custom_functions/index.md b/docs/built_in_transformers/advanced_transformers/custom_functions/index.md index 42506cd7..1cf4f6fc 100644 --- a/docs/built_in_transformers/advanced_transformers/custom_functions/index.md +++ b/docs/built_in_transformers/advanced_transformers/custom_functions/index.md @@ -1,6 +1,11 @@ # Template custom functions -Within Greenmask, custom functions play a crucial role, providing a wide array of options for implementing diverse logic. Under the hood, the custom functions are based on the [sprig Go's template functions](https://masterminds.github.io/sprig/). Greenmask enhances this capability by introducing additional functions and transformation functions. These extensions mirror the logic found in the [standard transformers](../../standard_transformers/index.md) but offer you the flexibility to implement intricate and comprehensive logic tailored to your specific needs. +Within Greenmask, custom functions play a crucial role, providing a wide array of options for implementing diverse +logic. Under the hood, the custom functions are based on +the [sprig Go's template functions](https://masterminds.github.io/sprig/). Greenmask enhances this capability by +introducing additional functions and transformation functions. These extensions mirror the logic found in +the [standard transformers](../../standard_transformers/index.md) but offer you the flexibility to implement intricate +and comprehensive logic tailored to your specific needs. Currently, you can use template custom functions for the [advanced transformers](../index.md): @@ -8,7 +13,10 @@ Currently, you can use template custom functions for the [advanced transformers] * [Template](../template.md) * [TemplateRecord](../template_record.md) +and for the [Transformation condition feature](../../transformation_condition.md) as well. + Custom functions are arbitrarily divided into 2 groups: -- [Core functions](core_functions.md) — custom functions that vary in purpose and include PostgreSQL driver, JSON output, testing, and transformation functions. +- [Core functions](core_functions.md) — custom functions that vary in purpose and include PostgreSQL driver, JSON + output, testing, and transformation functions. - [Faker functions](faker_function.md) — custom function of a *faker* type which generate synthetic data. diff --git a/docs/built_in_transformers/advanced_transformers/index.md b/docs/built_in_transformers/advanced_transformers/index.md index d72dce44..7c6666e3 100644 --- a/docs/built_in_transformers/advanced_transformers/index.md +++ b/docs/built_in_transformers/advanced_transformers/index.md @@ -5,6 +5,6 @@ Advanced transformers are modifiable anonymization methods that users can adjust Below you can find an index of all advanced transformers currently available in Greenmask. 1. [Json](json.md) — changes a JSON content by using `delete` and `set` operations. -1. [Template](template.md) — executes a Go template of your choice and applies the result to a specified column. -1. [TemplateRecord](template_record.md) — modifies records by using a Go template of your choice and applies the changes via the PostgreSQL +2. [Template](template.md) — executes a Go template of your choice and applies the result to a specified column. +3. [TemplateRecord](template_record.md) — modifies records by using a Go template of your choice and applies the changes via the PostgreSQL driver. diff --git a/docs/built_in_transformers/transformation_condition.md b/docs/built_in_transformers/transformation_condition.md new file mode 100644 index 00000000..6b0b1cd4 --- /dev/null +++ b/docs/built_in_transformers/transformation_condition.md @@ -0,0 +1,160 @@ +# Transformation Condition + +## Description + +The transformation condition feature allows you to execute a defined transformation only if a specified condition is +met. +The condition must be defined as a boolean expression that evaluates to `true` or `false`. Greenmask uses +[expr-lang/expr](https://github.com/expr-lang/expr) under the hood. You can use all functions and syntax provided by the +`expr` library. + +You can use the same functions that are described in +the [built-in transformers](/docs/built_in_transformers/advanced_transformers/custom_functions/index.md) + +The transformers are executed one by one - this helps you create complex transformation pipelines. For instance +depending on value chosen in the previous transformer, you can decide to execute the next transformer or not. + +## Record descriptors + +To improve the user experience, Greenmask offers special namespaces for accessing values in different formats: either +the driver-encoded value in its real type or as a raw string. + +- **`record`**: This namespace provides the record value in its actual type. +- **`raw_record`**: This namespace provides the record value as a string. + +You can access a specific column’s value using `record.column_name` for the real type or `raw_record.column_name` for +the raw string value. + +!!! warning + + A record may always be modified by previous transformers before the condition is evaluated. This means Greenmask does + not retain the original record value and instead provides the current modified value for condition evaluation. + +## Null values condition + +To check if the value is null, you can use `null` value for the comparisson. This operation works compatibly +with SQL operator `IS NULL` or `IS NOT NULL`. + +```text title="Is null cond example" +record.accountnumber == null && record.date > now() +``` + +```text title="Is not null cond example" +record.accountnumber != null && record.date <= now() +``` + +## Expression scope + +Expression scope can be on table or specific transformer. If you define the condition on the table scope, then the +condition will be evaluated before any transformer is executed. If you define the condition on the transformer scope, +then the condition will be evaluated before the specified transformer is executed. + +```yaml title="Table scope" +- schema: "purchasing" + name: "vendor" + when: 'record.accountnumber == null || record.accountnumber == "ALLENSON0001"' + transformers: + - name: "RandomString" + params: + column: "accountnumber" + min_length: 9 + max_length: 12 + symbols: "1234567890ABCDEFGHIJKLMNOPQRSTUVWXYZ" +``` + +```yaml title="Transformer scope" +- schema: "purchasing" + name: "vendor" + transformers: + - name: "RandomString" + when: 'record.accountnumber != null || record.accountnumber == "ALLENSON0001"' + params: + column: "accountnumber" + min_length: 9 + max_length: 12 + symbols: "1234567890ABCDEFGHIJKLMNOPQRSTUVWXYZ" +``` + +## Int and float value definition + +It is important to create the integer or float value in the correct format. If you want to define the integer value you +must write a number without dot (`1`, `2`, etc.). If you want to define the float value you must write a number with +dot (`1.0`, `2.0`, etc.). + +!!! warning + + You may see a wrong comparison result if you compare int and float, for example `1 == 1.0` will return `false`. + +## Architecture + +Greenmask encodes the way only when evaluating the condition - this allows to optimize the performance of the +transformation if you have a lot of conditions that uses or (`||`) or and (`&&`) operators. + +## Example: Chose random value and execute one of + +In the following example, the `RandomChoice` transformer is used to choose a random value from the list of values. +Depending on the chosen value, the `Replace` transformer is executed to set the `activeflag` column to `true` or +`false`. + +In this case the condition scope is on the transformer level. + +```yaml +- schema: "purchasing" + name: "vendor" + transformers: + - name: "RandomChoice" + params: + column: "name" + values: + - "test1" + - "test2" + + - name: "Replace" + when: 'record.name == "test1"' + params: + column: "activeflag" + value: "false" + + - name: "Replace" + when: 'record.name == "test2"' + params: + column: "activeflag" + value: "true" +``` + +## Example: Do not transform specific columns + +In the following example, the `RandomString` transformer is executed only if the `businessentityid` column value is not +equal to `1492` or `1`. + +```yaml + - schema: "purchasing" + name: "vendor" + when: '!(record.businessentityid | has([1492, 1]))' + transformers: + - name: "RandomString" + params: + column: "accountnumber" + min_length: 9 + max_length: 12 + symbols: "1234567890ABCDEFGHIJKLMNOPQRSTUVWXYZ" +``` + +## Example: Check the json attribute value + +In the following example, the `RandomString` transformer is executed only if the `a` attribute in the `json_data` column +is equal to `1`. + +```yaml +- schema: "public" + name: "jsondata" + when: 'raw_record.json_data | jsonGet("a") == 1' + transformers: + - name: "RandomString" + params: + column: "accountnumber" + min_length: 9 + max_length: 12 + symbols: "1234567890ABCDEFGHIJKLMNOPQRSTUVWXYZ" +``` + diff --git a/go.mod b/go.mod index 7d1a2225..450fde7d 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/Masterminds/sprig/v3 v3.3.0 github.com/aws/aws-sdk-go v1.55.5 github.com/dchest/siphash v1.2.3 + github.com/expr-lang/expr v1.16.7 github.com/ggwhite/go-masker v1.1.0 github.com/go-faker/faker/v4 v4.5.0 github.com/google/uuid v1.6.0 diff --git a/go.sum b/go.sum index b90501b2..393984a4 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dchest/siphash v1.2.3 h1:QXwFc8cFOR2dSa/gE6o/HokBMWtLUaNDVd+22aKHeEA= github.com/dchest/siphash v1.2.3/go.mod h1:0NvQU092bT0ipiFN++/rXm69QG9tVxLAlQHIXMPAkHc= +github.com/expr-lang/expr v1.16.7 h1:gCIiHt5ODA0xIaDbD0DPKyZpM9Drph3b3lolYAYq2Kw= +github.com/expr-lang/expr v1.16.7/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= diff --git a/internal/db/postgres/cmd/validate.go b/internal/db/postgres/cmd/validate.go index 252866c3..7a7a61da 100644 --- a/internal/db/postgres/cmd/validate.go +++ b/internal/db/postgres/cmd/validate.go @@ -130,9 +130,15 @@ func (v *Validate) Run(ctx context.Context) (int, error) { return nonZeroExitCode, fmt.Errorf("unable to build runtime context: %w", err) } - if err = v.printValidationWarnings(); err != nil { + err = toolkit.PrintValidationWarnings( + v.context.Warnings, v.config.Validate.ResolvedWarnings, v.config.Validate.Warnings, + ) + if err != nil { return nonZeroExitCode, err } + if v.context.IsFatal() { + return nonZeroExitCode, fmt.Errorf("fatal validation error") + } if err = v.diffWithPreviousSchema(ctx); err != nil { return nonZeroExitCode, err @@ -280,34 +286,6 @@ func (v *Validate) createDocument(ctx context.Context, t *entries.Table) (valida return doc, nil } -func (v *Validate) printValidationWarnings() error { - // TODO: Implement warnings hook, such as logging and HTTP sender - for _, w := range v.context.Warnings { - w.MakeHash() - if idx := slices.Index(v.config.Validate.ResolvedWarnings, w.Hash); idx != -1 { - log.Debug().Str("hash", w.Hash).Msg("resolved warning has been excluded") - if w.Severity == toolkit.ErrorValidationSeverity { - return fmt.Errorf("warning with hash %s cannot be excluded because it is an error", w.Hash) - } - continue - } - - if w.Severity == toolkit.ErrorValidationSeverity { - // The warnings with error severity must be printed anyway - log.Error().Any("ValidationWarning", w).Msg("") - } else { - // Print warnings with severity level lower than ErrorValidationSeverity only if requested - if v.config.Validate.Warnings { - log.Warn().Any("ValidationWarning", w).Msg("") - } - } - } - if v.context.IsFatal() { - return fmt.Errorf("fatal validation error") - } - return nil -} - func (v *Validate) getTablesToValidate() ([]*domains.Table, error) { var tablesToValidate []*domains.Table for _, tv := range v.config.Validate.Tables { diff --git a/internal/db/postgres/cmd/validate_utils/json_document_test.go b/internal/db/postgres/cmd/validate_utils/json_document_test.go index 0f74db51..d3987716 100644 --- a/internal/db/postgres/cmd/validate_utils/json_document_test.go +++ b/internal/db/postgres/cmd/validate_utils/json_document_test.go @@ -15,26 +15,6 @@ import ( "github.com/greenmaskio/greenmask/pkg/toolkit" ) -type testTransformer struct{} - -func (tt *testTransformer) Init(ctx context.Context) error { - return nil -} - -func (tt *testTransformer) Done(ctx context.Context) error { - return nil -} - -func (tt *testTransformer) Transform(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { - return nil, nil -} - -func (tt *testTransformer) GetAffectedColumns() map[int]string { - return map[int]string{ - 1: "name", - } -} - func TestJsonDocument_GetAffectedColumns(t *testing.T) { tab, _, _ := getTableAndRows() jd := NewJsonDocument(tab, true, true) @@ -87,6 +67,26 @@ func TestJsonDocument_GetRecords(t *testing.T) { //r.SetRow(row) } +type testTransformer struct{} + +func (tt *testTransformer) Init(ctx context.Context) error { + return nil +} + +func (tt *testTransformer) Done(ctx context.Context) error { + return nil +} + +func (tt *testTransformer) Transform(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { + return nil, nil +} + +func (tt *testTransformer) GetAffectedColumns() map[int]string { + return map[int]string{ + 1: "name", + } +} + func getTableAndRows() (table *entries.Table, original, transformed [][]byte) { tableDef := ` diff --git a/internal/db/postgres/context/table.go b/internal/db/postgres/context/table.go index 321f5260..f61c312f 100644 --- a/internal/db/postgres/context/table.go +++ b/internal/db/postgres/context/table.go @@ -124,7 +124,7 @@ func validateAndBuildTablesConfig( // InitTransformation toolkit if len(tableCfg.Transformers) > 0 { for _, tc := range tableCfg.Transformers { - transformer, initWarnings, err := initTransformer(ctx, driver, tc, registry, types) + transformer, initWarnings, err := initTransformer(ctx, driver, tc, registry) if len(initWarnings) > 0 { for _, w := range initWarnings { // Enriching the tables context into meta @@ -155,6 +155,7 @@ func validateAndBuildTablesConfig( func getTable(ctx context.Context, tx pgx.Tx, t *domains.Table) ([]*entries.Table, toolkit.ValidationWarnings, error) { table := &entries.Table{ Table: &toolkit.Table{}, + When: t.When, } var warnings toolkit.ValidationWarnings var tables []*entries.Table @@ -204,6 +205,7 @@ func getTable(ctx context.Context, tx pgx.Tx, t *domains.Table) ([]*entries.Tabl RootPtSchema: table.Schema, RootPtName: table.Name, RootOid: table.Oid, + When: table.When, } if err = rows.Scan(&pt.Oid, &pt.Schema, &pt.Name); err != nil { return nil, nil, fmt.Errorf("error scanning TableGetChildPatsQuery: %w", err) diff --git a/internal/db/postgres/context/transformers.go b/internal/db/postgres/context/transformers.go index 0c611ad5..dc6739cd 100644 --- a/internal/db/postgres/context/transformers.go +++ b/internal/db/postgres/context/transformers.go @@ -27,7 +27,6 @@ func initTransformer( ctx context.Context, d *toolkit.Driver, c *domains.TransformerConfig, r *transformersUtils.TransformerRegistry, - types []*toolkit.Type, ) (*transformersUtils.TransformerContext, toolkit.ValidationWarnings, error) { var totalWarnings toolkit.ValidationWarnings td, ok := r.Get(c.Name) @@ -35,14 +34,14 @@ func initTransformer( totalWarnings = append(totalWarnings, toolkit.NewValidationWarning(). SetMsg("transformer not found"). - SetSeverity(toolkit.ErrorValidationSeverity).SetTrace(&toolkit.Trace{ - SchemaName: d.Table.Schema, - TableName: d.Table.Name, - TransformerName: c.Name, - })) + AddMeta("SchemaName", d.Table.Schema). + AddMeta("TableName", d.Table.Name). + AddMeta("TransformerName", c.Name). + SetSeverity(toolkit.ErrorValidationSeverity), + ) return nil, totalWarnings, nil } - transformer, warnings, err := td.Instance(ctx, d, c.Params, c.DynamicParams) + transformer, warnings, err := td.Instance(ctx, d, c.Params, c.DynamicParams, c.When) if err != nil { return nil, nil, fmt.Errorf("unable to init transformer: %w", err) } diff --git a/internal/db/postgres/dumpers/transformation_pipeline.go b/internal/db/postgres/dumpers/transformation_pipeline.go index f083253e..ece19c13 100644 --- a/internal/db/postgres/dumpers/transformation_pipeline.go +++ b/internal/db/postgres/dumpers/transformation_pipeline.go @@ -30,11 +30,9 @@ import ( "github.com/greenmaskio/greenmask/pkg/toolkit" ) -const tmpFilePath = "/tmp" - var endOfLineSeq = []byte("\n") -type TransformationFunc func(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) +type transformationFunc func(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) type TransformationPipeline struct { table *entries.Table @@ -42,15 +40,17 @@ type TransformationPipeline struct { w io.Writer line uint64 row *pgcopy.Row - transformationWindows []*TransformationWindow - Transform TransformationFunc + transformationWindows []*transformationWindow + Transform transformationFunc isAsync bool record *toolkit.Record + // when - table level when condition + when *toolkit.WhenCond } func NewTransformationPipeline(ctx context.Context, eg *errgroup.Group, table *entries.Table, w io.Writer) (*TransformationPipeline, error) { - var tws []*TransformationWindow + var tws []*transformationWindow var isAsync bool // TODO: Fix this hint. Async execution cannot be performed with template record because it is unsafe. @@ -62,13 +62,13 @@ func NewTransformationPipeline(ctx context.Context, eg *errgroup.Group, table *e if !hasTemplateRecordTransformer && table.HasCustomTransformer() && len(table.TransformersContext) > 1 { isAsync = true - tw := NewTransformationWindow(ctx, eg) + tw := newTransformationWindow(ctx, eg) tws = append(tws, tw) for _, t := range table.TransformersContext { - if !tw.TryAdd(table, t.Transformer) { - tw = NewTransformationWindow(ctx, eg) + if !tw.tryAdd(table, t) { + tw = newTransformationWindow(ctx, eg) tws = append(tws, tw) - tw.TryAdd(table, t.Transformer) + tw.tryAdd(table, t) } } } @@ -90,12 +90,26 @@ func NewTransformationPipeline(ctx context.Context, eg *errgroup.Group, table *e record: record, } - var tf TransformationFunc = tp.TransformSync + var tf transformationFunc = tp.TransformSync if isAsync { tf = tp.TransformAsync } tp.Transform = tf + mata := map[string]any{ + "TableSchema": table.Schema, + "TableName": table.Name, + } + + whenCond, warnings := toolkit.NewWhenCond(table.When, table.Driver, mata) + if err := toolkit.PrintValidationWarnings(warnings, nil, true); err != nil { + return nil, err + } + if warnings.IsFatal() { + return nil, fmt.Errorf("unable to compile when condition: fatal error") + } + tp.when = whenCond + return tp, nil } @@ -123,7 +137,7 @@ func (tp *TransformationPipeline) Init(ctx context.Context) error { } if tp.isAsync { for _, w := range tp.transformationWindows { - w.Init() + w.init() } } @@ -131,8 +145,14 @@ func (tp *TransformationPipeline) Init(ctx context.Context) error { } func (tp *TransformationPipeline) TransformSync(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { - var err error for _, t := range tp.table.TransformersContext { + needTransform, err := t.EvaluateWhen(r) + if err != nil { + return nil, NewDumpError(tp.table.Schema, tp.table.Name, tp.line, fmt.Errorf("error evaluating when condition: %w", err)) + } + if !needTransform { + continue + } _, err = t.Transformer.Transform(ctx, r) if err != nil { return nil, NewDumpError(tp.table.Schema, tp.table.Name, tp.line, err) @@ -159,13 +179,21 @@ func (tp *TransformationPipeline) Dump(ctx context.Context, data []byte) (err er } tp.record.SetRow(tp.row) - _, err = tp.Transform(ctx, tp.record) + needTransform, err := tp.when.Evaluate(tp.record) if err != nil { - return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, err) + return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, fmt.Errorf("error evaluating when condition: %w", err)) } + + if needTransform { + _, err = tp.Transform(ctx, tp.record) + if err != nil { + return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, err) + } + } + rowDriver, err := tp.record.Encode() if err != nil { - return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, fmt.Errorf("error enocding to RowDriver: %w", err)) + return NewDumpError(tp.table.Schema, tp.table.Name, tp.line, fmt.Errorf("error enocding Record to RowDriver: %w", err)) } res, err := rowDriver.Encode() if err != nil { @@ -204,7 +232,7 @@ func (tp *TransformationPipeline) Done(ctx context.Context) error { } if tp.isAsync { for _, w := range tp.transformationWindows { - w.Done() + w.close() } } diff --git a/internal/db/postgres/dumpers/transformation_pipeline_test.go b/internal/db/postgres/dumpers/transformation_pipeline_test.go new file mode 100644 index 00000000..77b35678 --- /dev/null +++ b/internal/db/postgres/dumpers/transformation_pipeline_test.go @@ -0,0 +1,108 @@ +package dumpers + +import ( + "bytes" + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils" + "github.com/greenmaskio/greenmask/pkg/toolkit" +) + +func TestTransformationPipeline_Dump(t *testing.T) { + termCtx, termCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer termCancel() + table := getTable() + ctx := context.Background() + eg, gtx := errgroup.WithContext(ctx) + driver := getDriver(table.Table) + table.Driver = driver + when, warns := toolkit.NewWhenCond("", driver, nil) + require.Empty(t, warns) + tt := &testTransformer{} + tc := &utils.TransformerContext{ + Transformer: tt, + When: when, + } + table.TransformersContext = []*utils.TransformerContext{tc} + + buf := bytes.NewBuffer(nil) + + pipeline, err := NewTransformationPipeline(gtx, eg, table, buf) + require.NoError(t, err) + require.NoError(t, pipeline.Init(termCtx)) + data := []byte("1\t2023-08-27 00:00:00.000000") + err = pipeline.Dump(ctx, data) + require.NoError(t, err) + require.NoError(t, pipeline.Done(termCtx)) + require.NoError(t, pipeline.CompleteDump()) + require.Equal(t, tt.callsCount, 1) + require.Equal(t, buf.String(), "2\t2023-08-27 00:00:00.00000\n\\.\n\n") +} + +func TestTransformationPipeline_Dump_with_transformer_cond(t *testing.T) { + termCtx, termCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer termCancel() + table := getTable() + ctx := context.Background() + eg, gtx := errgroup.WithContext(ctx) + driver := getDriver(table.Table) + table.Driver = driver + when, warns := toolkit.NewWhenCond("record.id != 1", driver, make(map[string]any)) + require.Empty(t, warns) + tt := &testTransformer{} + tc := &utils.TransformerContext{ + Transformer: tt, + When: when, + } + table.TransformersContext = []*utils.TransformerContext{tc} + + buf := bytes.NewBuffer(nil) + + pipeline, err := NewTransformationPipeline(gtx, eg, table, buf) + require.NoError(t, err) + require.NoError(t, pipeline.Init(termCtx)) + data := []byte("1\t2023-08-27 00:00:00.000000") + err = pipeline.Dump(ctx, data) + require.NoError(t, err) + require.NoError(t, pipeline.Done(termCtx)) + require.NoError(t, pipeline.CompleteDump()) + require.Equal(t, tt.callsCount, 0) + require.Equal(t, buf.String(), "1\t2023-08-27 00:00:00.00000\n\\.\n\n") +} + +func TestTransformationPipeline_Dump_with_table_cond(t *testing.T) { + termCtx, termCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer termCancel() + table := getTable() + ctx := context.Background() + eg, gtx := errgroup.WithContext(ctx) + driver := getDriver(table.Table) + table.Driver = driver + when, warns := toolkit.NewWhenCond("", driver, make(map[string]any)) + require.Empty(t, warns) + tt := &testTransformer{} + tc := &utils.TransformerContext{ + Transformer: tt, + When: when, + } + table.TransformersContext = []*utils.TransformerContext{tc} + table.When = "record.id != 1" + + buf := bytes.NewBuffer(nil) + + pipeline, err := NewTransformationPipeline(gtx, eg, table, buf) + require.NoError(t, err) + require.NoError(t, pipeline.Init(termCtx)) + data := []byte("1\t2023-08-27 00:00:00.000000") + err = pipeline.Dump(ctx, data) + require.NoError(t, err) + require.NoError(t, pipeline.Done(termCtx)) + require.NoError(t, pipeline.CompleteDump()) + require.Equal(t, tt.callsCount, 0) + require.Equal(t, buf.String(), "1\t2023-08-27 00:00:00.00000\n\\.\n\n") +} diff --git a/internal/db/postgres/dumpers/transformation_window.go b/internal/db/postgres/dumpers/transformation_window.go index eb1a8a4e..2fd72b10 100644 --- a/internal/db/postgres/dumpers/transformation_window.go +++ b/internal/db/postgres/dumpers/transformation_window.go @@ -16,6 +16,7 @@ package dumpers import ( "context" + "fmt" "sync" "golang.org/x/sync/errgroup" @@ -25,20 +26,23 @@ import ( "github.com/greenmaskio/greenmask/pkg/toolkit" ) -type TransformationWindow struct { +type asyncContext struct { + tc *utils.TransformerContext + ch chan struct{} +} + +type transformationWindow struct { affectedColumns map[string]struct{} - transformers []utils.Transformer - chs []chan struct{} + window []*asyncContext done chan struct{} wg *sync.WaitGroup eg *errgroup.Group r *toolkit.Record ctx context.Context - size int } -func NewTransformationWindow(ctx context.Context, eg *errgroup.Group) *TransformationWindow { - return &TransformationWindow{ +func newTransformationWindow(ctx context.Context, eg *errgroup.Group) *transformationWindow { + return &transformationWindow{ affectedColumns: map[string]struct{}{}, done: make(chan struct{}, 1), wg: &sync.WaitGroup{}, @@ -47,12 +51,11 @@ func NewTransformationWindow(ctx context.Context, eg *errgroup.Group) *Transform } } -func (tw *TransformationWindow) TryAdd(table *entries.Table, t utils.Transformer) bool { +func (tw *transformationWindow) tryAdd(table *entries.Table, t *utils.TransformerContext) bool { - affectedColumn := t.GetAffectedColumns() + affectedColumn := t.Transformer.GetAffectedColumns() if len(affectedColumn) == 0 { - if len(tw.transformers) == 0 { - tw.transformers = append(tw.transformers, t) + if len(tw.window) == 0 { for _, c := range table.Columns { tw.affectedColumns[c.Name] = struct{}{} } @@ -68,20 +71,20 @@ func (tw *TransformationWindow) TryAdd(table *entries.Table, t utils.Transformer for _, name := range affectedColumn { tw.affectedColumns[name] = struct{}{} } - tw.transformers = append(tw.transformers, t) } - ch := make(chan struct{}, 1) - tw.chs = append(tw.chs, ch) - tw.size++ + tw.window = append(tw.window, &asyncContext{ + tc: t, + ch: make(chan struct{}, 1), + }) return true } -func (tw *TransformationWindow) Init() { - for idx, t := range tw.transformers { - ch := tw.chs[idx] - func(t utils.Transformer, ch chan struct{}) { +// init - runs all transformers in the goroutines and waits for the ac.ch signal to run the transformer +func (tw *transformationWindow) init() { + for _, ac := range tw.window { + func(ac *asyncContext) { tw.eg.Go(func() error { for { select { @@ -89,9 +92,9 @@ func (tw *TransformationWindow) Init() { return tw.ctx.Err() case <-tw.done: return nil - case <-ch: + case <-ac.ch: } - _, err := t.Transform(tw.ctx, tw.r) + _, err := ac.tc.Transformer.Transform(tw.ctx, tw.r) if err != nil { tw.wg.Done() return err @@ -99,26 +102,36 @@ func (tw *TransformationWindow) Init() { tw.wg.Done() } }) - }(t, ch) + }(ac) } } -func (tw *TransformationWindow) Done() { +// close - closes the done channel to stop the transformers goroutines +func (tw *transformationWindow) close() { close(tw.done) } -func (tw *TransformationWindow) Transform(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { - - tw.wg.Add(tw.size) +// Transform - runs the transformation for the record in the window. This function checks the when +// condition of the transformer and if true sends a signal to the transformer goroutine to run the transformation +func (tw *transformationWindow) Transform(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { tw.r = r - for _, ch := range tw.chs { + for _, ac := range tw.window { + needTransform, err := ac.tc.EvaluateWhen(r) + if err != nil { + return nil, fmt.Errorf("error evaluating when condition: %w", err) + } + if !needTransform { + continue + } + + tw.wg.Add(1) select { case <-ctx.Done(): return nil, ctx.Err() case <-tw.ctx.Done(): return nil, tw.ctx.Err() - case ch <- struct{}{}: + case ac.ch <- struct{}{}: } } diff --git a/internal/db/postgres/dumpers/transformation_window_test.go b/internal/db/postgres/dumpers/transformation_window_test.go new file mode 100644 index 00000000..1b411c9f --- /dev/null +++ b/internal/db/postgres/dumpers/transformation_window_test.go @@ -0,0 +1,158 @@ +package dumpers + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/greenmaskio/greenmask/internal/db/postgres/entries" + "github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils" + "github.com/greenmaskio/greenmask/pkg/toolkit" + "github.com/greenmaskio/greenmask/pkg/toolkit/testutils" +) + +func TestTransformationWindow_tryAdd(t *testing.T) { + ctx := context.Background() + eg, gtx := errgroup.WithContext(ctx) + tw := newTransformationWindow(gtx, eg) + tc := utils.TransformerContext{ + Transformer: &testTransformer{}, + } + table := getTable() + require.True(t, tw.tryAdd(table, &tc)) + require.False(t, tw.tryAdd(table, &tc)) +} + +func TestTransformationWindow_Transform(t *testing.T) { + mainCtx, mainCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer mainCancel() + eg, gtx := errgroup.WithContext(mainCtx) + tw := newTransformationWindow(gtx, eg) + when, warns := toolkit.NewWhenCond("", nil, nil) + require.Empty(t, warns) + tc := utils.TransformerContext{ + Transformer: &testTransformer{}, + When: when, + } + table := getTable() + require.True(t, tw.tryAdd(table, &tc)) + + driver := getDriver(table.Table) + record := toolkit.NewRecord(driver) + row := testutils.NewTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000"}) + record.SetRow(row) + tw.init() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, err := tw.Transform(ctx, record) + require.NoError(t, err) + v, err := record.GetRawColumnValueByName("id") + require.NoError(t, err) + require.False(t, v.IsNull) + require.Equal(t, []byte("2"), v.Data) + tw.close() + require.NoError(t, eg.Wait()) +} + +func TestTransformationWindow_Transform_with_cond(t *testing.T) { + table := getTable() + driver := getDriver(table.Table) + record := toolkit.NewRecord(driver) + when, warns := toolkit.NewWhenCond("record.id != 1", driver, make(map[string]any)) + require.Empty(t, warns) + mainCtx, mainCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer mainCancel() + eg, gtx := errgroup.WithContext(mainCtx) + tw := newTransformationWindow(gtx, eg) + tt := &testTransformer{} + tc := utils.TransformerContext{ + Transformer: tt, + When: when, + } + require.True(t, tw.tryAdd(table, &tc)) + + row := testutils.NewTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000"}) + record.SetRow(row) + tw.init() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, err := tw.Transform(ctx, record) + require.NoError(t, err) + require.Equal(t, 0, tt.callsCount) + v, err := record.GetRawColumnValueByName("id") + require.NoError(t, err) + require.False(t, v.IsNull) + require.Equal(t, []byte("1"), v.Data) + tw.close() + require.NoError(t, eg.Wait()) +} + +type testTransformer struct { + callsCount int +} + +func (tt *testTransformer) Init(ctx context.Context) error { + return nil +} + +func (tt *testTransformer) Done(ctx context.Context) error { + return nil +} + +func (tt *testTransformer) Transform(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { + tt.callsCount++ + err := r.SetColumnValueByName("id", 2) + if err != nil { + return nil, err + } + return r, nil +} + +func (tt *testTransformer) GetAffectedColumns() map[int]string { + return map[int]string{ + 1: "name", + } +} + +func getDriver(table *toolkit.Table) *toolkit.Driver { + driver, _, err := toolkit.NewDriver(table, nil) + if err != nil { + panic(err.Error()) + } + return driver +} + +func getTable() *entries.Table { + return &entries.Table{ + Table: &toolkit.Table{ + Schema: "public", + Name: "test", + Oid: 1224, + Columns: []*toolkit.Column{ + { + Name: "id", + TypeName: "int2", + TypeOid: pgtype.Int2OID, + Num: 1, + NotNull: true, + Length: -1, + }, + { + Name: "created_at", + TypeName: "timestamp", + TypeOid: pgtype.TimestampOID, + Num: 1, + NotNull: true, + Length: -1, + }, + }, + Constraints: []toolkit.Constraint{}, + }, + } +} diff --git a/internal/db/postgres/entries/table.go b/internal/db/postgres/entries/table.go index 629fb4a9..ef5d9db9 100644 --- a/internal/db/postgres/entries/table.go +++ b/internal/db/postgres/entries/table.go @@ -46,6 +46,7 @@ type Table struct { Driver *toolkit.Driver Scores int64 SubsetConds []string + When string } func (t *Table) HasCustomTransformer() bool { diff --git a/internal/db/postgres/transformers/column_context.go b/internal/db/postgres/transformers/column_context.go index e522e69a..dbc6f455 100644 --- a/internal/db/postgres/transformers/column_context.go +++ b/internal/db/postgres/transformers/column_context.go @@ -49,7 +49,7 @@ func (cc *ColumnContext) GetValue() (any, error) { } func (cc *ColumnContext) GetRawValue() (any, error) { - return cc.rc.GetRawColumnValue(cc.columnName) + return cc.rc.GetColumnRawValue(cc.columnName) } func (cc *ColumnContext) GetColumnValue(name string) (any, error) { @@ -57,7 +57,7 @@ func (cc *ColumnContext) GetColumnValue(name string) (any, error) { } func (cc *ColumnContext) GetColumnRawValue(name string) (any, error) { - return cc.rc.GetRawColumnValue(name) + return cc.rc.GetColumnRawValue(name) } func (cc *ColumnContext) EncodeValue(v any) (any, error) { diff --git a/internal/db/postgres/transformers/dict_test.go b/internal/db/postgres/transformers/dict_test.go index 8eec7890..d1744248 100644 --- a/internal/db/postgres/transformers/dict_test.go +++ b/internal/db/postgres/transformers/dict_test.go @@ -41,6 +41,7 @@ func TestDictTransformer_Transform_with_fail(t *testing.T) { context.Background(), driver, params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -85,6 +86,7 @@ func TestDictTransformer_Transform_validation_error(t *testing.T) { context.Background(), driver, params, nil, + "", ) require.NoError(t, err) require.NotEmpty(t, warnings) @@ -101,6 +103,7 @@ func TestDictTransformer_Transform_validation_error(t *testing.T) { context.Background(), driver, params, nil, + "", ) require.NoError(t, err) require.NotEmpty(t, warnings) @@ -121,6 +124,7 @@ func TestDictTransformer_Transform_error_not_matched(t *testing.T) { context.Background(), driver, params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -149,6 +153,7 @@ func TestDictTransformer_Transform_use_default(t *testing.T) { context.Background(), driver, params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -180,6 +185,7 @@ func TestDictTransformer_Transform_with_int_values(t *testing.T) { context.Background(), driver, params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/email.go b/internal/db/postgres/transformers/email.go index 3ec79f3c..815dcb81 100644 --- a/internal/db/postgres/transformers/email.go +++ b/internal/db/postgres/transformers/email.go @@ -92,7 +92,7 @@ type EmailTransformer struct { originalDomain []byte randomBytesBuf []byte hexEncodedRandomBytesBuf []byte - rrctx *RoRecordContext + rctx *toolkit.RecordContext } func NewEmailTransformer(ctx context.Context, driver *toolkit.Driver, parameters map[string]toolkit.Parameterizer) (utils.Transformer, toolkit.ValidationWarnings, error) { @@ -142,7 +142,7 @@ func NewEmailTransformer(ctx context.Context, driver *toolkit.Driver, parameters return nil, nil, fmt.Errorf(`unable to scan "domain_part_template" param: %w`, err) } - rrctx := NewRoRecordContext() + rrctx := toolkit.NewRecordContext() funcMap := toolkit.FuncMap() if localPartTemplate != "" || domainTemplate != "" { for _, c := range driver.Table.Columns { @@ -208,8 +208,8 @@ func NewEmailTransformer(ctx context.Context, driver *toolkit.Driver, parameters domainTemplate: domainTmpl, validate: validate, buf: bytes.NewBuffer(nil), - hexEncodedRandomBytesBuf: make([]byte, hex.EncodedLen(maxLength)), - rrctx: rrctx, + hexEncodedRandomBytesBuf: make([]byte, hex.EncodedLen(emailTransformerGeneratorSize)), + rctx: rrctx, }, nil, nil } @@ -264,7 +264,7 @@ func (rit *EmailTransformer) setupTemplateContext(originalEmail []byte, r *toolk if rit.localPartTemplate == nil && rit.domainTemplate == nil && !rit.keepOriginalDomain { return nil } - rit.rrctx.setRecord(r) + rit.rctx.SetRecord(r) originalLocalPart, originalDomain, err := EmailParse(originalEmail) if err != nil { diff --git a/internal/db/postgres/transformers/hash_test.go b/internal/db/postgres/transformers/hash_test.go index 15ed8a71..f3c6151c 100644 --- a/internal/db/postgres/transformers/hash_test.go +++ b/internal/db/postgres/transformers/hash_test.go @@ -77,6 +77,7 @@ func TestHashTransformer_Transform_all_functions(t *testing.T) { context.Background(), driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -154,6 +155,7 @@ func TestHashTransformer_Transform_length_truncation(t *testing.T) { context.Background(), driver, params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -184,6 +186,7 @@ func TestHashTransformer_Transform_multiple_iterations(t *testing.T) { context.Background(), driver, params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/json_context.go b/internal/db/postgres/transformers/json_context.go index 2e36596f..f26862c3 100644 --- a/internal/db/postgres/transformers/json_context.go +++ b/internal/db/postgres/transformers/json_context.go @@ -47,7 +47,7 @@ func (jc *JsonContext) GetColumnValue(name string) (any, error) { } func (jc *JsonContext) GetColumnRawValue(name string) (any, error) { - return jc.rc.GetRawColumnValue(name) + return jc.rc.GetColumnRawValue(name) } func (jc *JsonContext) EncodeValueByColumn(name string, v any) (any, error) { diff --git a/internal/db/postgres/transformers/json_test.go b/internal/db/postgres/transformers/json_test.go index cef9bd33..7a23b10b 100644 --- a/internal/db/postgres/transformers/json_test.go +++ b/internal/db/postgres/transformers/json_test.go @@ -44,6 +44,7 @@ func TestJsonTransformer_Transform(t *testing.T) { ]`), }, nil, + "", ) require.NoError(t, err) assert.Empty(t, warnings) @@ -86,6 +87,7 @@ func TestJsonTransformer_Transform_with_template(t *testing.T) { "operations": opsData, }, nil, + "", ) require.NoError(t, err) assert.Empty(t, warnings) @@ -130,6 +132,7 @@ func TestJsonTransformer_Transform_null(t *testing.T) { "keep_null": toolkit.ParamsValue("false"), }, nil, + "", ) require.NoError(t, err) assert.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/masking_test.go b/internal/db/postgres/transformers/masking_test.go index 56b17b21..89d6673a 100644 --- a/internal/db/postgres/transformers/masking_test.go +++ b/internal/db/postgres/transformers/masking_test.go @@ -79,6 +79,7 @@ func TestMaskingTransformer_Transform(t *testing.T) { context.Background(), driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -109,6 +110,7 @@ func TestMaskingTransformer_type_validation(t *testing.T) { "type": toolkit.ParamsValue("unknown"), }, nil, + "", ) require.NoError(t, err) assert.Len(t, warnings, 1) diff --git a/internal/db/postgres/transformers/noise_date_test.go b/internal/db/postgres/transformers/noise_date_test.go index af84f361..35cd2ecc 100644 --- a/internal/db/postgres/transformers/noise_date_test.go +++ b/internal/db/postgres/transformers/noise_date_test.go @@ -118,6 +118,7 @@ func TestNoiseDateTransformer_Transform(t *testing.T) { context.Background(), driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/noise_float_test.go b/internal/db/postgres/transformers/noise_float_test.go index 9024d1a2..1c705c9a 100644 --- a/internal/db/postgres/transformers/noise_float_test.go +++ b/internal/db/postgres/transformers/noise_float_test.go @@ -131,6 +131,7 @@ func TestNoiseFloatTransformer_Transform(t *testing.T) { driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/noise_int_test.go b/internal/db/postgres/transformers/noise_int_test.go index 012ffadd..0a6adc6b 100644 --- a/internal/db/postgres/transformers/noise_int_test.go +++ b/internal/db/postgres/transformers/noise_int_test.go @@ -104,6 +104,7 @@ func TestNoiseIntTransformer_Transform(t *testing.T) { context.Background(), driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/noise_numeric_test.go b/internal/db/postgres/transformers/noise_numeric_test.go index 61600277..5daf460f 100644 --- a/internal/db/postgres/transformers/noise_numeric_test.go +++ b/internal/db/postgres/transformers/noise_numeric_test.go @@ -121,6 +121,7 @@ func TestNoiseNumericTransformer_Transform(t *testing.T) { driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/random_bool_test.go b/internal/db/postgres/transformers/random_bool_test.go index b76849fa..22463753 100644 --- a/internal/db/postgres/transformers/random_bool_test.go +++ b/internal/db/postgres/transformers/random_bool_test.go @@ -67,6 +67,7 @@ func TestRandomBoolTransformer_Transform(t *testing.T) { driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/random_choice_test.go b/internal/db/postgres/transformers/random_choice_test.go index cc4bd721..77b9b5fb 100644 --- a/internal/db/postgres/transformers/random_choice_test.go +++ b/internal/db/postgres/transformers/random_choice_test.go @@ -26,6 +26,7 @@ func TestRandomChoiceTransformer_Transform_with_fail(t *testing.T) { context.Background(), driver, params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -57,6 +58,7 @@ func TestRandomChoiceTransformer_Transform_validation_error(t *testing.T) { context.Background(), driver, params, nil, + "", ) require.NoError(t, err) require.NotEmpty(t, warnings) @@ -78,6 +80,7 @@ func TestRandomChoiceTransformer_Transform_json(t *testing.T) { context.Background(), driver, params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/random_date_test.go b/internal/db/postgres/transformers/random_date_test.go index 91ef1297..5409d264 100644 --- a/internal/db/postgres/transformers/random_date_test.go +++ b/internal/db/postgres/transformers/random_date_test.go @@ -116,6 +116,7 @@ func TestTimestampTransformer_Transform(t *testing.T) { context.Background(), driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/random_float_test.go b/internal/db/postgres/transformers/random_float_test.go index dfd7cb7a..cc6e909a 100644 --- a/internal/db/postgres/transformers/random_float_test.go +++ b/internal/db/postgres/transformers/random_float_test.go @@ -159,6 +159,7 @@ func TestRandomFloatTransformer_Transform(t *testing.T) { driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/random_int_test.go b/internal/db/postgres/transformers/random_int_test.go index e009c84e..5fe12ef3 100644 --- a/internal/db/postgres/transformers/random_int_test.go +++ b/internal/db/postgres/transformers/random_int_test.go @@ -131,6 +131,7 @@ func TestRandomIntTransformer_Transform_random_static(t *testing.T) { driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -205,6 +206,7 @@ func TestRandomIntTransformer_Transform_random_dynamic(t *testing.T) { driver, tt.params, tt.dynamicParams, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -284,6 +286,7 @@ func TestRandomIntTransformer_Transform_deterministic_dynamic(t *testing.T) { driver, tt.params, tt.dynamicParams, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/random_ip_test.go b/internal/db/postgres/transformers/random_ip_test.go index 4ef4c89c..033c1327 100644 --- a/internal/db/postgres/transformers/random_ip_test.go +++ b/internal/db/postgres/transformers/random_ip_test.go @@ -69,6 +69,7 @@ func TestRandomIpTransformer_Transform_random_dynamic(t *testing.T) { driver, tt.params, tt.dynamicParams, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/random_mac_test.go b/internal/db/postgres/transformers/random_mac_test.go index 30dbae0d..a624c772 100644 --- a/internal/db/postgres/transformers/random_mac_test.go +++ b/internal/db/postgres/transformers/random_mac_test.go @@ -131,6 +131,7 @@ func TestRandomMacTransformer_Transform_random(t *testing.T) { driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/random_numeric_test.go b/internal/db/postgres/transformers/random_numeric_test.go index 56d87d19..9f763bd6 100644 --- a/internal/db/postgres/transformers/random_numeric_test.go +++ b/internal/db/postgres/transformers/random_numeric_test.go @@ -95,6 +95,7 @@ func TestBigIntTransformer_Transform_random_static(t *testing.T) { driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -175,6 +176,7 @@ func TestBigIntTransformer_Transform_random_dynamic(t *testing.T) { driver, tt.params, tt.dynamicParams, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -254,6 +256,7 @@ func TestBigIntTransformer_Transform_deterministic_dynamic(t *testing.T) { driver, tt.params, tt.dynamicParams, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/random_person_test.go b/internal/db/postgres/transformers/random_person_test.go index ac0b69a8..cf40040d 100644 --- a/internal/db/postgres/transformers/random_person_test.go +++ b/internal/db/postgres/transformers/random_person_test.go @@ -6,11 +6,12 @@ import ( "strings" "testing" + "github.com/rs/zerolog/log" + "github.com/stretchr/testify/require" + "github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils" "github.com/greenmaskio/greenmask/internal/generators/transformers" "github.com/greenmaskio/greenmask/pkg/toolkit" - "github.com/rs/zerolog/log" - "github.com/stretchr/testify/require" ) func TestRandomPersonTransformer_Transform_static_fullname(t *testing.T) { @@ -32,6 +33,7 @@ func TestRandomPersonTransformer_Transform_static_fullname(t *testing.T) { driver, params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -69,6 +71,7 @@ func TestRandomPersonTransformer_Transform_static_firstname(t *testing.T) { driver, params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -106,6 +109,7 @@ func TestRandomPersonTransformer_Transform_static_lastname(t *testing.T) { driver, params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -150,6 +154,7 @@ func TestRandomPersonTransformer_Transform_static_nullable(t *testing.T) { driver, params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/random_string_test.go b/internal/db/postgres/transformers/random_string_test.go index 8db5fe04..1efd36da 100644 --- a/internal/db/postgres/transformers/random_string_test.go +++ b/internal/db/postgres/transformers/random_string_test.go @@ -91,6 +91,7 @@ func TestRandomStringTransformer_Transform(t *testing.T) { driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/random_unix_timestamp_test.go b/internal/db/postgres/transformers/random_unix_timestamp_test.go index 83821476..44da8821 100644 --- a/internal/db/postgres/transformers/random_unix_timestamp_test.go +++ b/internal/db/postgres/transformers/random_unix_timestamp_test.go @@ -125,6 +125,7 @@ func TestUnixTimestampTransformer_Transform__positive_cases__static(t *testing.T context.Background(), driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -200,6 +201,7 @@ func TestUnixTimestampTransformer_Transform_null_cases(t *testing.T) { context.Background(), driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -276,6 +278,7 @@ func TestUnixTimestampTransformer_Transform_dynamic(t *testing.T) { driver, tt.params, tt.dynamicParams, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/random_uuid_test.go b/internal/db/postgres/transformers/random_uuid_test.go index 3a2358a5..4bf31f66 100644 --- a/internal/db/postgres/transformers/random_uuid_test.go +++ b/internal/db/postgres/transformers/random_uuid_test.go @@ -76,6 +76,7 @@ func TestUuidTransformer_Transform_uuid_type(t *testing.T) { driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/real_address_test.go b/internal/db/postgres/transformers/real_address_test.go index 44a1dbff..d40e1372 100644 --- a/internal/db/postgres/transformers/real_address_test.go +++ b/internal/db/postgres/transformers/real_address_test.go @@ -44,6 +44,7 @@ func TestRealAddressTransformer_Transform(t *testing.T) { "columns": rawData, }, nil, + "", ) require.NoError(t, err) @@ -77,6 +78,7 @@ func TestMakeNewFakeTransformerFunction_parsing_error(t *testing.T) { "columns": rawData, }, nil, + "", ) require.Len(t, warnings, 1) require.Equal(t, "error parsing template", warnings[0].Msg) @@ -102,6 +104,7 @@ func TestMakeNewFakeTransformerFunction_validation_error(t *testing.T) { "columns": rawData, }, nil, + "", ) require.Len(t, warnings, 1) require.Equal(t, "error validating template", warnings[0].Msg) diff --git a/internal/db/postgres/transformers/regexp_replace_test.go b/internal/db/postgres/transformers/regexp_replace_test.go index a26f096c..fe27e797 100644 --- a/internal/db/postgres/transformers/regexp_replace_test.go +++ b/internal/db/postgres/transformers/regexp_replace_test.go @@ -52,6 +52,7 @@ func TestRegexpReplaceTransformer_Transform2(t *testing.T) { driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/replace_test.go b/internal/db/postgres/transformers/replace_test.go index 996ec672..1d31675d 100644 --- a/internal/db/postgres/transformers/replace_test.go +++ b/internal/db/postgres/transformers/replace_test.go @@ -87,6 +87,7 @@ func TestReplaceTransformer_Transform(t *testing.T) { driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -146,6 +147,7 @@ func TestReplaceTransformer_Transform_with_raw_value(t *testing.T) { driver, tt.params, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -185,6 +187,7 @@ func TestReplaceTransformer_Transform_with_validation_error(t *testing.T) { driver, params, nil, + "", ) require.NoError(t, err) assert.NotEmpty(t, warnings) diff --git a/internal/db/postgres/transformers/ro_record_context.go b/internal/db/postgres/transformers/ro_record_context.go deleted file mode 100644 index 671adac8..00000000 --- a/internal/db/postgres/transformers/ro_record_context.go +++ /dev/null @@ -1,47 +0,0 @@ -package transformers - -import ( - "github.com/greenmaskio/greenmask/pkg/toolkit" -) - -type RoRecordContext struct { - rc *toolkit.RecordContext -} - -func NewRoRecordContext() *RoRecordContext { - return &RoRecordContext{ - rc: &toolkit.RecordContext{}, - } -} - -func (cc *RoRecordContext) clean() { - cc.rc.Clean() -} - -func (cc *RoRecordContext) setRecord(r *toolkit.Record) { - cc.rc.SetRecord(r) -} - -func (cc *RoRecordContext) GetColumnValue(name string) (any, error) { - return cc.rc.GetColumnValue(name) -} - -func (cc *RoRecordContext) GetColumnRawValue(name string) (any, error) { - return cc.rc.GetRawColumnValue(name) -} - -func (cc *RoRecordContext) EncodeValueByColumn(name string, v any) (any, error) { - return cc.rc.EncodeValueByColumn(name, v) -} - -func (cc *RoRecordContext) DecodeValueByColumn(name string, v any) (any, error) { - return cc.rc.DecodeValueByColumn(name, v) -} - -func (cc *RoRecordContext) EncodeValueByType(name string, v any) (any, error) { - return cc.rc.EncodeValueByType(name, v) -} - -func (cc *RoRecordContext) DecodeValueByType(name string, v any) (any, error) { - return cc.rc.DecodeValueByType(name, v) -} diff --git a/internal/db/postgres/transformers/set_null_test.go b/internal/db/postgres/transformers/set_null_test.go index 7f10cdaf..dae225ae 100644 --- a/internal/db/postgres/transformers/set_null_test.go +++ b/internal/db/postgres/transformers/set_null_test.go @@ -37,6 +37,7 @@ func TestSetNullTransformer_Transform(t *testing.T) { "column": toolkit.ParamsValue(columnName), }, nil, + "", ) require.NoError(t, err) assert.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/template_record_test.go b/internal/db/postgres/transformers/template_record_test.go index 4a5424b6..e8b8866a 100644 --- a/internal/db/postgres/transformers/template_record_test.go +++ b/internal/db/postgres/transformers/template_record_test.go @@ -60,6 +60,7 @@ func TestTemplateRecordTransformer_Transform_date(t *testing.T) { "template": toolkit.ParamsValue(template), }, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) @@ -85,7 +86,7 @@ func TestTemplateRecordTransformer_Transform_date(t *testing.T) { func TestTemplateRecordTransformer_Transform_json(t *testing.T) { var columnName = "doc" var template = ` - {{ $val := .GetRawColumnValue "doc" }} + {{ $val := .GetColumnRawValue "doc" }} {{ jsonSet "name" "hello" $val | jsonValidate | .SetColumnValue "doc" }} ` @@ -109,6 +110,7 @@ func TestTemplateRecordTransformer_Transform_json(t *testing.T) { "template": toolkit.ParamsValue(template), }, nil, + "", ) require.NoError(t, err) require.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/template_test.go b/internal/db/postgres/transformers/template_test.go index 3da253cc..c1526288 100644 --- a/internal/db/postgres/transformers/template_test.go +++ b/internal/db/postgres/transformers/template_test.go @@ -48,6 +48,7 @@ func TestTemplateTransformer_Transform_int(t *testing.T) { "template": toolkit.ParamsValue(template), }, nil, + "", ) require.NoError(t, err) assert.Empty(t, warnings) @@ -86,6 +87,7 @@ func TestTemplateTransformer_Transform_timestamp(t *testing.T) { "template": toolkit.ParamsValue(template), }, nil, + "", ) require.NoError(t, err) assert.Empty(t, warnings) diff --git a/internal/db/postgres/transformers/utils/definition.go b/internal/db/postgres/transformers/utils/definition.go index c09ef43b..62556c45 100644 --- a/internal/db/postgres/transformers/utils/definition.go +++ b/internal/db/postgres/transformers/utils/definition.go @@ -52,14 +52,9 @@ func (d *TransformerDefinition) SetSchemaValidator(v SchemaValidationFunc) *Tran return d } -type TransformerContext struct { - Transformer Transformer - StaticParameters map[string]*toolkit.StaticParameter - DynamicParameters map[string]*toolkit.DynamicParameter -} - func (d *TransformerDefinition) Instance( ctx context.Context, driver *toolkit.Driver, rawParams map[string]toolkit.ParamsValue, dynamicParameters map[string]*toolkit.DynamicParamValue, + whenCond string, ) (*TransformerContext, toolkit.ValidationWarnings, error) { // DecodeValue parameters and get the pgcopy of parsed params, parametersWarnings, err := toolkit.InitParameters(driver, d.Parameters, rawParams, dynamicParameters) @@ -103,9 +98,30 @@ func (d *TransformerDefinition) Instance( res = append(res, schemaWarnings...) res = append(res, transformerWarnings...) + meta := map[string]interface{}{ + "TableSchema": driver.Table.Schema, + "TableName": driver.Table.Name, + "Transformer": d.Properties.Name, + } + + when, condWarns := toolkit.NewWhenCond(whenCond, driver, meta) + res = append(res, condWarns...) + return &TransformerContext{ Transformer: t, StaticParameters: staticParams, DynamicParameters: dynamicParams, + When: when, }, res, nil } + +type TransformerContext struct { + Transformer Transformer + StaticParameters map[string]*toolkit.StaticParameter + DynamicParameters map[string]*toolkit.DynamicParameter + When *toolkit.WhenCond +} + +func (tc *TransformerContext) EvaluateWhen(r *toolkit.Record) (bool, error) { + return tc.When.Evaluate(r) +} diff --git a/internal/db/postgres/transformers/utils/definition_test.go b/internal/db/postgres/transformers/utils/definition_test.go index 02eb72e6..f23adfe5 100644 --- a/internal/db/postgres/transformers/utils/definition_test.go +++ b/internal/db/postgres/transformers/utils/definition_test.go @@ -100,7 +100,7 @@ func TestDefinition(t *testing.T) { "replace": []byte("2023-08-27 12:08:11.304895"), } - _, warnings, err := TestTransformerDefinition.Instance(context.Background(), driver, rawParams, nil) + _, warnings, err := TestTransformerDefinition.Instance(context.Background(), driver, rawParams, nil, "") require.NoError(t, err) assert.Empty(t, warnings) } diff --git a/internal/domains/config.go b/internal/domains/config.go index 84d9a5c3..7083e8b3 100644 --- a/internal/domains/config.go +++ b/internal/domains/config.go @@ -144,6 +144,7 @@ type TransformerConfig struct { // this is used only due to https://github.com/spf13/viper/issues/373 MetadataParams map[string]any `mapstructure:"-" yaml:"params,omitempty" json:"params,omitempty"` DynamicParams toolkit.DynamicParameters `mapstructure:"dynamic_params" yaml:"dynamic_params" json:"dynamic_params,omitempty"` + When string `mapstructure:"when" yaml:"when" json:"when,omitempty"` } type Table struct { @@ -154,6 +155,7 @@ type Table struct { Transformers []*TransformerConfig `mapstructure:"transformers" yaml:"transformers" json:"transformers,omitempty"` ColumnsTypeOverride map[string]string `mapstructure:"columns_type_override" yaml:"columns_type_override" json:"columns_type_override,omitempty"` SubsetConds []string `mapstructure:"subset_conds" yaml:"subset_conds" json:"subset_conds,omitempty"` + When string `mapstructure:"when" yaml:"when" json:"when,omitempty"` } // DummyConfig - This is a dummy config to the viper workaround diff --git a/mkdocs.yml b/mkdocs.yml index b8774aa1..d05ae518 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -64,6 +64,7 @@ nav: - Dynamic parameters: built_in_transformers/dynamic_parameters.md - Transformation engines: built_in_transformers/transformation_engines.md - Parameters templating: built_in_transformers/parameters_templating.md + - Transformation conditions: built_in_transformers/transformation_condition.md - Standard transformers: - built_in_transformers/standard_transformers/index.md - Cmd: built_in_transformers/standard_transformers/cmd.md diff --git a/pkg/toolkit/dynamic_parameter.go b/pkg/toolkit/dynamic_parameter.go index 10535f60..99bb2aff 100644 --- a/pkg/toolkit/dynamic_parameter.go +++ b/pkg/toolkit/dynamic_parameter.go @@ -58,7 +58,7 @@ func (dpc *DynamicParameterContext) GetValue() (any, error) { } func (dpc *DynamicParameterContext) GetRawValue() (any, error) { - return dpc.rc.GetRawColumnValue(dpc.column.Name) + return dpc.rc.GetColumnRawValue(dpc.column.Name) } func (dpc *DynamicParameterContext) GetColumnValue(name string) (any, error) { @@ -66,7 +66,7 @@ func (dpc *DynamicParameterContext) GetColumnValue(name string) (any, error) { } func (dpc *DynamicParameterContext) GetColumnRawValue(name string) (any, error) { - return dpc.rc.GetRawColumnValue(name) + return dpc.rc.GetColumnRawValue(name) } func (dpc *DynamicParameterContext) EncodeValue(v any) (any, error) { diff --git a/pkg/toolkit/expr.go b/pkg/toolkit/expr.go new file mode 100644 index 00000000..134c8168 --- /dev/null +++ b/pkg/toolkit/expr.go @@ -0,0 +1,243 @@ +// An expression handler for the toolkit package. It is used to evaluate the when condition of the record. +// Might be used in transformation conditions and other places where the record is used. + +package toolkit + +import ( + "fmt" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/vm" + "github.com/rs/zerolog/log" +) + +const ( + recordExprNamespace = "record" + rawRecordExprNamespace = "raw_record" +) + +// WhenCond - A condition that should be evaluated to determine if the record should be processed. +type WhenCond struct { + rc *RecordContext + whenCond *vm.Program + when string + env map[string]any +} + +// NewWhenCond - creates a new WhenCond object. It compiles the when condition and returns the compiled program +// and the record context with the functions for the columns. The functions represent the column names and return the +// column values. If the when condition is empty, the WhenCond object will always return true. +func NewWhenCond(when string, driver *Driver, meta map[string]any) (*WhenCond, ValidationWarnings) { + var ( + rc *RecordContext + whenCond *vm.Program + warnings ValidationWarnings + ) + if when != "" { + whenCond, rc, warnings = compileCond(when, driver, meta) + if warnings.IsFatal() { + return nil, warnings + } + } + env := FuncMap() + env["null"] = NullValue + return &WhenCond{ + rc: rc, + whenCond: whenCond, + when: when, + env: env, + }, nil +} + +// Evaluate - evaluates the when condition. If the when condition is empty, it will always return true. +func (wc *WhenCond) Evaluate(r *Record) (bool, error) { + if wc.whenCond == nil { + return true, nil + } + wc.rc.SetRecord(r) + + output, err := expr.Run(wc.whenCond, wc.env) + if err != nil { + return false, fmt.Errorf("unable to evaluate when condition: %w", err) + } + + cond, ok := output.(bool) + if ok { + return cond, nil + } + + return false, fmt.Errorf("when condition should return boolean, got (%T) and value %+v", cond, cond) +} + +// compileCond compiles the when condition and returns the compiled program and the record context +// with the functions for the columns. The functions represent the column names and return the column values. +// meta - additional meta information for debugging the compilation process +func compileCond(whenCond string, driver *Driver, meta map[string]any) ( + *vm.Program, *RecordContext, ValidationWarnings, +) { + if whenCond == "" { + return nil, nil, nil + } + scope := "table" + if _, ok := meta["Transformer"]; ok { + scope = "transformer" + } + meta["Scope"] = scope + log.Debug(). + Str("WhenCond", whenCond). + Any("Meta", meta). + Msg("found when condition: compiling") + rc, ops := newRecordContext(driver) + ops = append(ops, expr.Patch(newExprPatcher(meta))) + + cond, err := expr.Compile(whenCond, ops...) + if err != nil { + return nil, nil, ValidationWarnings{ + NewValidationWarning(). + SetSeverity(ErrorValidationSeverity). + AddMeta("Error", err.Error()). + SetMsg("unable to compile when condition"), + } + } + + return cond, rc, nil +} + +// newRecordContext creates a new record context and create kind of column descriptors for the record to access the +// column values by the column name. For instance if the column name is "name", the function __name will return +// the value +func newRecordContext(driver *Driver) (*RecordContext, []expr.Option) { + var funcs []expr.Option + rctx := NewRecordContext() + for _, c := range driver.Table.Columns { + + // create a function that returns the column value by the column name. The returned value is encoded using + // pgx driver + typedFunc := expr.Function( + fmt.Sprintf("__%s", c.Name), + func(name string) func(params ...any) (any, error) { + return func(params ...any) (any, error) { + v, err := rctx.GetColumnValue(name) + if err != nil { + return nil, err + } + // convert the value to the appropriate type for expr library + // the expected types must be nil, bool, int, uint, float32, string, array, map + switch vv := v.(type) { + case float32: + return float64(vv), nil + case int64: + return int(vv), nil + case int32: + return int(vv), nil + case int16: + return int(vv), nil + case int8: + return int(vv), nil + case byte: + return int(vv), nil + case uint64: + return uint(vv), nil + case uint32: + return uint(vv), nil + case uint16: + return uint(vv), nil + } + return v, nil + } + }(c.Name), + ) + funcs = append(funcs, typedFunc) + + rawFunc := expr.Function( + fmt.Sprintf("__raw__%s", c.Name), + func(name string) func(params ...any) (any, error) { + return func(params ...any) (any, error) { + return rctx.GetColumnRawValue(name) + } + }(c.Name), + ) + funcs = append(funcs, rawFunc) + } + return rctx, funcs +} + +// exprPatcher - patcher for the expression compiler. It patches the expression tree by some identifiers to +// function calls. For instance is null, is not null, records address +type exprPatcher struct { + meta map[string]any +} + +func newExprPatcher(meta map[string]any) *exprPatcher { + return &exprPatcher{ + meta: meta, + } +} + +func (ep *exprPatcher) Visit(node *ast.Node) { + log.Debug(). + Any("Meta", ep.meta). + Any("Node", node). + Type("NodeType", *node). + Str("NodeFmt", fmt.Sprintf("%+v", *node)). + Msg("debugging expr tree nodes") + if isRecordOp(node) { + patchRecordOp(node) + } +} + +// isRecordOp checks if the node is a record operation +func isRecordOp(node *ast.Node) bool { + mn, ok := (*node).(*ast.MemberNode) + if !ok { + return false + } + owner, ok := (mn.Node).(*ast.IdentifierNode) + if !ok { + return false + } + _, ok = (mn.Property).(*ast.StringNode) + if !ok { + return false + } + return owner.Value == recordExprNamespace || owner.Value == rawRecordExprNamespace +} + +// patchRecordOp patches the record access operation +// 1. record.id -> __id() function call for decoding the column value into type using pgx driver +// 2. raw_record.id -> __raw_id() function call getting a raw value as a string +func patchRecordOp(node *ast.Node) { + mn, ok := (*node).(*ast.MemberNode) + if !ok { + return + } + owner, ok := (mn.Node).(*ast.IdentifierNode) + if !ok { + return + } + attr, ok := (mn.Property).(*ast.StringNode) + if !ok { + return + } + var newOp *ast.CallNode + switch owner.Value { + case recordExprNamespace: + newOp = &ast.CallNode{ + Callee: &ast.IdentifierNode{ + Value: fmt.Sprintf("__%s", attr.Value), + }, + } + + case rawRecordExprNamespace: + newOp = &ast.CallNode{ + Callee: &ast.IdentifierNode{ + Value: fmt.Sprintf("__raw__%s", attr.Value), + }, + } + } + + log.Debug().Any("OriginalNode", node).Any("NewNode", newOp).Msg("patching record operation") + ast.Patch(node, newOp) + +} diff --git a/pkg/toolkit/expt_test.go b/pkg/toolkit/expt_test.go new file mode 100644 index 00000000..c90ed8d0 --- /dev/null +++ b/pkg/toolkit/expt_test.go @@ -0,0 +1,67 @@ +package toolkit + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWhenCond_Evaluate(t *testing.T) { + driver := getDriver() + record := NewRecord(driver) + row := newTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000", testNullSeq, `{"a": 1}`, "123.0"}) + record.SetRow(row) + + type test struct { + name string + when string + expected bool + } + tests := []test{ + { + name: "int value equal", + when: "record.id == 1", + expected: true, + }, + { + name: "raw int value equal", + when: "raw_record.id == \"1\"", + expected: true, + }, + { + name: "is null value check", + when: "record.title == null", + expected: true, + }, + { + name: "test date cmp", + when: "record.created_at > now()", + expected: false, + }, + { + name: "test json cmp and sping func", + when: `raw_record.json_data | jsonGet("a") == 1`, + expected: false, + }, + { + name: "check has array func", + when: `record.id | has([1, 2, 3, 9223372036854775807])`, + expected: true, + }, + { + name: "float cmp", + when: `record.float_data | has([123.0, 1., 10.])`, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + whenCond, warns := NewWhenCond(tt.when, driver, make(map[string]any)) + require.Empty(t, warns) + res, err := whenCond.Evaluate(record) + require.NoError(t, err) + require.Equal(t, tt.expected, res) + }) + } +} diff --git a/pkg/toolkit/record_test.go b/pkg/toolkit/record_test.go index 80ae3f91..a17134a8 100644 --- a/pkg/toolkit/record_test.go +++ b/pkg/toolkit/record_test.go @@ -53,6 +53,22 @@ func getDriver() *Driver { NotNull: true, Length: -1, }, + { + Name: "json_data", + TypeName: "jsonb", + TypeOid: pgtype.JSONBOID, + Num: 4, + NotNull: true, + Length: -1, + }, + { + Name: "float_data", + TypeName: "float4", + TypeOid: pgtype.Float4OID, + Num: 5, + NotNull: true, + Length: 4, + }, }, Constraints: []Constraint{}, } @@ -64,9 +80,7 @@ func getDriver() *Driver { } func TestRecord_ScanAttribute(t *testing.T) { - row := &TestRowDriver{ - row: []string{"1", "2023-08-27 00:00:00.000000"}, - } + row := newTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000", ""}) driver := getDriver() r := NewRecord(driver) r.SetRow(row) @@ -78,9 +92,7 @@ func TestRecord_ScanAttribute(t *testing.T) { } func TestRecord_GetAttribute_date(t *testing.T) { - row := &TestRowDriver{ - row: []string{"1", "2023-08-27 00:00:00.000000", "1234"}, - } + row := newTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000", ""}) driver := getDriver() r := NewRecord(driver) r.SetRow(row) @@ -92,9 +104,7 @@ func TestRecord_GetAttribute_date(t *testing.T) { } func TestRecord_GetAttribute_text(t *testing.T) { - row := &TestRowDriver{ - row: []string{"1", "2023-08-27 00:00:00.000000", "1234"}, - } + row := newTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000", "1234", ""}) driver := getDriver() r := NewRecord(driver) r.SetRow(row) @@ -105,32 +115,29 @@ func TestRecord_GetAttribute_text(t *testing.T) { assert.Equal(t, expected.Value, res.Value) } -func TestRecord_GetTuple(t *testing.T) { - expected := Tuple{ - "id": NewValue(int16(1), false), - "created_at": NewValue(time.Date(2023, time.August, 27, 0, 0, 0, 0, time.UTC), false), - "title": NewValue(nil, true), - } - row := &TestRowDriver{ - row: []string{"1", "2023-08-27 00:00:00.000000", testNullSeq}, - } - driver := getDriver() - r := NewRecord(driver) - r.SetRow(row) - res, err := r.GetTuple() - require.NoError(t, err) - for name := range expected { - assert.Equalf(t, expected[name].IsNull, res[name].IsNull, "wrong IsNull value %s", name) - assert.Equalf(t, expected[name].Value, res[name].Value, "wrong Value %s", name) - } - -} +//func TestRecord_GetTuple(t *testing.T) { +// expected := Tuple{ +// "id": NewValue(int16(1), false), +// "created_at": NewValue(time.Date(2023, time.August, 27, 0, 0, 0, 0, time.UTC), false), +// "title": NewValue(nil, true), +// } +// row := &TestRowDriver{ +// row: []string{"1", "2023-08-27 00:00:00.000000", testNullSeq, "", ""}, +// } +// driver := getDriver() +// r := NewRecord(driver) +// r.SetRow(row) +// res, err := r.GetTuple() +// require.NoError(t, err) +// for name := range expected { +// assert.Equalf(t, expected[name].IsNull, res[name].IsNull, "wrong IsNull value %s", name) +// assert.Equalf(t, expected[name].Value, res[name].Value, "wrong Value %s", name) +// } +//} func TestRecord_Encode(t *testing.T) { - row := &TestRowDriver{ - row: []string{"1", "2023-08-27 00:00:00.000001", "test"}, - } - expected := []byte("2\t2023-08-29 00:00:00.000002\t\\N") + row := newTestRowDriver([]string{"1", "2023-08-27 00:00:00.000001", "test", "", ""}) + expected := []byte("2\t2023-08-29 00:00:00.000002\t\\N\t\t") driver := getDriver() r := NewRecord(driver) r.SetRow(row) diff --git a/pkg/toolkit/template_record_context.go b/pkg/toolkit/template_record_context.go index e3a0da06..10daa7f7 100644 --- a/pkg/toolkit/template_record_context.go +++ b/pkg/toolkit/template_record_context.go @@ -61,7 +61,7 @@ func (rc *RecordContext) GetColumnValue(name string) (any, error) { return v.Value, nil } -func (rc *RecordContext) GetRawColumnValue(name string) (any, error) { +func (rc *RecordContext) GetColumnRawValue(name string) (any, error) { v, err := rc.record.GetRawColumnValueByName(name) if err != nil { return nil, err @@ -87,7 +87,7 @@ func (rc *RecordContext) SetColumnValue(name string, v any) (bool, error) { return true, nil } -func (rc *RecordContext) SetRawColumnValue(name string, v any) (bool, error) { +func (rc *RecordContext) SetColumnRawValue(name string, v any) (bool, error) { var val *RawValue switch vv := v.(type) { case NullType: diff --git a/pkg/toolkit/testutils_test.go b/pkg/toolkit/testutils.go similarity index 94% rename from pkg/toolkit/testutils_test.go rename to pkg/toolkit/testutils.go index d74441ff..31a26bae 100644 --- a/pkg/toolkit/testutils_test.go +++ b/pkg/toolkit/testutils.go @@ -21,6 +21,10 @@ type TestRowDriver struct { row []string } +func newTestRowDriver(row []string) *TestRowDriver { + return &TestRowDriver{row: row} +} + func (trd *TestRowDriver) GetColumn(idx int) (*RawValue, error) { val := trd.row[idx] if val == testNullSeq { diff --git a/pkg/toolkit/testutils/testutils.go b/pkg/toolkit/testutils/testutils.go new file mode 100644 index 00000000..d06d624a --- /dev/null +++ b/pkg/toolkit/testutils/testutils.go @@ -0,0 +1,68 @@ +// Copyright 2023 Greenmask +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutils + +import "github.com/greenmaskio/greenmask/pkg/toolkit" + +var NullSeq = "\\N" +var Delim byte = '\t' + +type TestRowDriver struct { + row []string +} + +func NewTestRowDriver(row []string) *TestRowDriver { + return &TestRowDriver{row: row} +} + +func (trd *TestRowDriver) GetColumn(idx int) (*toolkit.RawValue, error) { + val := trd.row[idx] + if val == NullSeq { + return toolkit.NewRawValue(nil, true), nil + } + return toolkit.NewRawValue([]byte(val), false), nil +} + +func (trd *TestRowDriver) SetColumn(idx int, v *toolkit.RawValue) error { + if v.IsNull { + trd.row[idx] = NullSeq + } else { + trd.row[idx] = string(v.Data) + } + return nil +} + +func (trd *TestRowDriver) Encode() ([]byte, error) { + var res []byte + for idx, v := range trd.row { + res = append(res, []byte(v)...) + if idx != len(trd.row)-1 { + res = append(res, Delim) + } + } + return res, nil +} + +func (trd *TestRowDriver) Decode([]byte) error { + panic("is not implemented") +} + +func (trd *TestRowDriver) Length() int { + return len(trd.row) +} + +func (trd *TestRowDriver) Clean() { + +} diff --git a/pkg/toolkit/validation_warning.go b/pkg/toolkit/validation_warning.go index e35db8e9..d48cc547 100644 --- a/pkg/toolkit/validation_warning.go +++ b/pkg/toolkit/validation_warning.go @@ -19,6 +19,8 @@ import ( "encoding/hex" "fmt" "slices" + + "github.com/rs/zerolog/log" ) const ( @@ -28,15 +30,6 @@ const ( DebugValidationSeverity = "debug" ) -// deprecated -type Trace struct { - SchemaName string `json:"schemaName,omitempty"` - TableName string `json:"tableName,omitempty"` - TransformerName string `json:"transformerName,omitempty"` - ParameterName string `json:"parameterName,omitempty"` - Msg string `json:"msg,omitempty"` -} - type ValidationWarnings []*ValidationWarning func (re ValidationWarnings) IsFatal() bool { @@ -48,7 +41,6 @@ func (re ValidationWarnings) IsFatal() bool { type ValidationWarning struct { Msg string `json:"msg,omitempty"` Severity string `json:"severity,omitempty"` - Trace *Trace `json:"trace,omitempty"` Meta map[string]any `json:"meta,omitempty"` Hash string `json:"hash"` } @@ -80,11 +72,6 @@ func (re *ValidationWarning) AddMeta(key string, value any) *ValidationWarning { return re } -func (re *ValidationWarning) SetTrace(value *Trace) *ValidationWarning { - re.Trace = value - return re -} - func (re *ValidationWarning) MakeHash() { var meta string keys := make([]string, 0, len(re.Meta)) @@ -102,3 +89,28 @@ func (re *ValidationWarning) MakeHash() { hash := md5.Sum([]byte(signature)) re.Hash = hex.EncodeToString(hash[:]) } + +func PrintValidationWarnings(warns ValidationWarnings, resolvedWarnings []string, printAll bool) error { + // TODO: Implement warnings hook, such as logging and HTTP sender + for _, w := range warns { + w.MakeHash() + if idx := slices.Index(resolvedWarnings, w.Hash); idx != -1 { + log.Debug().Str("hash", w.Hash).Msg("resolved warning has been excluded") + if w.Severity == ErrorValidationSeverity { + return fmt.Errorf("warning with hash %s cannot be excluded because it is an error", w.Hash) + } + continue + } + + if w.Severity == ErrorValidationSeverity { + // The warnings with error severity must be printed anyway + log.Error().Any("ValidationWarning", w).Msg("") + } else { + // Print warnings with severity level lower than ErrorValidationSeverity only if requested + if printAll { + log.Warn().Any("ValidationWarning", w).Msg("") + } + } + } + return nil +} diff --git a/pkg/toolkit/test/test.go b/tests/external_transformer/test.go similarity index 99% rename from pkg/toolkit/test/test.go rename to tests/external_transformer/test.go index 6929e339..dfed5ba4 100644 --- a/pkg/toolkit/test/test.go +++ b/tests/external_transformer/test.go @@ -1,4 +1,5 @@ // Copyright 2023 Greenmask + // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. From f3553eeca768d55bbfed2560afdd90812e1e263d Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Sun, 27 Oct 2024 21:28:51 +0200 Subject: [PATCH 2/2] fix: added raiseAnErrorIfSysIs32AndDriverReturns64 that raises panic for 32 bit sys due to limitation of go expr --- pkg/toolkit/expr.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pkg/toolkit/expr.go b/pkg/toolkit/expr.go index 134c8168..0e3db56b 100644 --- a/pkg/toolkit/expr.go +++ b/pkg/toolkit/expr.go @@ -5,6 +5,7 @@ package toolkit import ( "fmt" + "unsafe" "github.com/expr-lang/expr" "github.com/expr-lang/expr/ast" @@ -108,6 +109,7 @@ func compileCond(whenCond string, driver *Driver, meta map[string]any) ( // column values by the column name. For instance if the column name is "name", the function __name will return // the value func newRecordContext(driver *Driver) (*RecordContext, []expr.Option) { + intSize := unsafe.Sizeof(int(0)) * 8 var funcs []expr.Option rctx := NewRecordContext() for _, c := range driver.Table.Columns { @@ -128,6 +130,7 @@ func newRecordContext(driver *Driver) (*RecordContext, []expr.Option) { case float32: return float64(vv), nil case int64: + raiseAnErrorIfSysIs32AndDriverReturns64(intSize) return int(vv), nil case int32: return int(vv), nil @@ -241,3 +244,12 @@ func patchRecordOp(node *ast.Node) { ast.Patch(node, newOp) } + +// raiseAnErrorIfSysIs32AndDriverReturns64 - raises an error if the system is 32 bit and the driver returns 64 bit +// values. In 32-bit system int type is 32 bit but int 64 is 64 bit. In this case the int8 postgresql type cannot be +// handled using int type in go because it cast to int32 and loses the data. This is limitation of the go expr library +func raiseAnErrorIfSysIs32AndDriverReturns64(sysBytes uintptr) { + if sysBytes == 32 { + panic("go expr and pgx driver are not compatible to handle int8 postgresql type using int type in go") + } +}