From 8bbdead0fbc1051066228d38c3cdbdc878f283e2 Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Sat, 7 Dec 2024 17:50:18 +0200 Subject: [PATCH] feat: Test coverage for disable triggers in TableRestorer --- internal/db/postgres/restorers/base_test.go | 33 +++-- internal/db/postgres/restorers/table.go | 5 +- internal/db/postgres/restorers/table_test.go | 129 +++++++++++++++++++ 3 files changed, 157 insertions(+), 10 deletions(-) create mode 100644 internal/db/postgres/restorers/table_test.go diff --git a/internal/db/postgres/restorers/base_test.go b/internal/db/postgres/restorers/base_test.go index 6e89f82d..c2141b58 100644 --- a/internal/db/postgres/restorers/base_test.go +++ b/internal/db/postgres/restorers/base_test.go @@ -19,6 +19,9 @@ const ( migrationUp = ` CREATE USER non_super_user PASSWORD '1234' NOINHERIT; GRANT testuser TO non_super_user; +GRANT SELECT, INSERT ON ALL TABLES IN SCHEMA public TO non_super_user; +ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT, INSERT ON TABLES TO non_super_user; +GRANT INSERT ON ALL TABLES IN SCHEMA public TO non_super_user; -- Create the 'users' table CREATE TABLE users ( @@ -32,6 +35,7 @@ CREATE TABLE users ( CREATE TABLE orders ( id SERIAL PRIMARY KEY, user_id INT NOT NULL, + raise_error TEXT, order_amount NUMERIC(10, 2) NOT NULL, order_date DATE DEFAULT CURRENT_DATE, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, @@ -42,6 +46,9 @@ CREATE TABLE orders ( CREATE OR REPLACE FUNCTION set_order_date() RETURNS TRIGGER AS $$ BEGIN + If NEW.raise_error != '' THEN + RAISE EXCEPTION '%', NEW.raise_error; + END IF; IF NEW.order_date IS NULL THEN NEW.order_date = CURRENT_DATE; END IF; @@ -66,6 +73,14 @@ INSERT INTO orders (user_id, order_amount) VALUES (2, 200.75); ` migrationDown = ` +REVOKE ALL ON SCHEMA public FROM non_super_user; +REVOKE ALL ON ALL TABLES IN SCHEMA public FROM non_super_user; +REVOKE ALL ON ALL SEQUENCES IN SCHEMA public FROM non_super_user; +REVOKE ALL ON ALL FUNCTIONS IN SCHEMA public FROM non_super_user; +REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM non_super_user; +ALTER DEFAULT PRIVILEGES IN SCHEMA public REVOKE SELECT, INSERT ON TABLES FROM non_super_user; +REVOKE USAGE ON SCHEMA public FROM non_super_user; +REVOKE testuser FROM non_super_user; DROP USER non_super_user; DROP TRIGGER IF EXISTS trg_set_order_date ON orders; DROP FUNCTION IF EXISTS set_order_date; @@ -83,9 +98,9 @@ func (r *readCloserMock) Close() error { } type restoresSuite struct { + testutils.PgContainerSuite nonSuperUserPassword string nonSuperUser string - testutils.PgContainerSuite } func (s *restoresSuite) SetupSuite() { @@ -192,18 +207,18 @@ func (s *restoresSuite) Test_restoreBase_enableTriggers() { Namespace: &schemaName, Tag: &tableName, }, nil, opt) - cxt := context.Background() - conn, err := s.GetConnectionWithUser(cxt, s.nonSuperUser, s.nonSuperUserPassword) - defer conn.Close(cxt) + ctx := context.Background() + conn, err := s.GetConnectionWithUser(ctx, s.nonSuperUser, s.nonSuperUserPassword) + defer conn.Close(ctx) s.Require().NoError(err) - tx, err := conn.Begin(cxt) + tx, err := conn.Begin(ctx) s.Require().NoError(err) - err = rb.disableTriggers(cxt, tx) + err = rb.disableTriggers(ctx, tx) s.Require().NoError(err) expectedUser := s.nonSuperUser var actualUser string - r := tx.QueryRow(cxt, "SELECT current_user") + r := tx.QueryRow(ctx, "SELECT current_user") err = r.Scan(&actualUser) s.Require().NoError(err) s.Assert().Equal(expectedUser, actualUser) @@ -221,7 +236,7 @@ WHERE n.nspname = $1 AND c.relname = $2 AND t.tgname = ANY($3); ` rows, err := conn.Query( - cxt, checkDisabledTriggerSql, schemaName, tableName, []string{"trg_set_order_date"}, + ctx, checkDisabledTriggerSql, schemaName, tableName, []string{"trg_set_order_date"}, ) s.Require().NoError(err) defer rows.Close() @@ -248,7 +263,7 @@ WHERE n.nspname = $1 AND c.relname = $2 s.Assert().Equal(expected.tgenabled, triggers[i].tgenabled) } - s.NoError(tx.Rollback(cxt)) + s.NoError(tx.Rollback(ctx)) } func (s *restoresSuite) Test_restoreBase_disableTriggers() { diff --git a/internal/db/postgres/restorers/table.go b/internal/db/postgres/restorers/table.go index 5ac2b689..a1680893 100644 --- a/internal/db/postgres/restorers/table.go +++ b/internal/db/postgres/restorers/table.go @@ -94,7 +94,10 @@ func (td *TableRestorer) Execute(ctx context.Context, conn *pgx.Conn) error { if td.opt.ExitOnError { return fmt.Errorf("unable to restore table: %w", err) } - log.Warn().Err(err).Msg("unable to restore table") + log.Warn(). + Err(err). + Str("objectName", td.DebugInfo()). + Msg("unable to restore table") return nil } diff --git a/internal/db/postgres/restorers/table_test.go b/internal/db/postgres/restorers/table_test.go new file mode 100644 index 00000000..6246462d --- /dev/null +++ b/internal/db/postgres/restorers/table_test.go @@ -0,0 +1,129 @@ +package restorers + +import ( + "bytes" + "compress/gzip" + "context" + + "github.com/stretchr/testify/mock" + + "github.com/greenmaskio/greenmask/internal/db/postgres/pgrestore" + "github.com/greenmaskio/greenmask/internal/db/postgres/toc" + "github.com/greenmaskio/greenmask/internal/utils/testutils" +) + +func (s *restoresSuite) Test_TableRestorer_check_triggers_errors() { + s.Run("check triggers causes error by default", func() { + // Given + ctx := context.Background() + schemaName := "public" + tableName := "orders" + fileName := "test_table" + copyStmt := "COPY orders (id, user_id, order_amount, raise_error) FROM stdin;" + entry := &toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + FileName: &fileName, + CopyStmt: ©Stmt, + } + data := "3\t1\t100.50\tTest exception\n" + + "4\t1\t200.75\tTest exception\n" + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + s.Require().NoError(err) + err = gzData.Flush() + s.Require().NoError(err) + err = gzData.Close() + s.Require().NoError(err) + objSrc := &readCloserMock{Buffer: buf} + + st := new(testutils.StorageMock) + st.On("GetObject", ctx, mock.Anything).Return(objSrc, nil) + opt := &pgrestore.DataSectionSettings{ + ExitOnError: true, + } + tr := NewTableRestorer(entry, st, opt) + + conn, err := s.GetConnectionWithUser(ctx, s.nonSuperUser, s.nonSuperUserPassword) + err = tr.Execute(ctx, conn) + s.Require().ErrorContains(err, "Test exception (code P0001)") + }) + + s.Run("disable triggers", func() { + ctx := context.Background() + schemaName := "public" + tableName := "orders" + fileName := "test_table" + copyStmt := "COPY orders (id, user_id, order_amount, raise_error) FROM stdin;" + entry := &toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + FileName: &fileName, + CopyStmt: ©Stmt, + } + data := "3\t1\t100.50\tTest exception\n" + + "4\t1\t200.75\tTest exception\n" + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + s.Require().NoError(err) + err = gzData.Flush() + s.Require().NoError(err) + err = gzData.Close() + s.Require().NoError(err) + objSrc := &readCloserMock{Buffer: buf} + + st := new(testutils.StorageMock) + st.On("GetObject", ctx, mock.Anything).Return(objSrc, nil) + opt := &pgrestore.DataSectionSettings{ + ExitOnError: true, + DisableTriggers: true, + SuperUser: s.GetSuperUser(), + } + tr := NewTableRestorer(entry, st, opt) + + conn, err := s.GetConnectionWithUser(ctx, s.nonSuperUser, s.nonSuperUserPassword) + err = tr.Execute(ctx, conn) + s.Require().NoError(err) + }) + + s.Run("session_replication_role is replica", func() { + ctx := context.Background() + schemaName := "public" + tableName := "orders" + fileName := "test_table" + copyStmt := "COPY orders (id, user_id, order_amount, raise_error) FROM stdin;" + entry := &toc.Entry{ + Namespace: &schemaName, + Tag: &tableName, + FileName: &fileName, + CopyStmt: ©Stmt, + } + data := "3\t1\t100.50\tTest exception\n" + + "4\t1\t200.75\tTest exception\n" + buf := new(bytes.Buffer) + gzData := gzip.NewWriter(buf) + _, err := gzData.Write([]byte(data)) + s.Require().NoError(err) + err = gzData.Flush() + s.Require().NoError(err) + err = gzData.Close() + s.Require().NoError(err) + objSrc := &readCloserMock{Buffer: buf} + + st := new(testutils.StorageMock) + st.On("GetObject", ctx, mock.Anything).Return(objSrc, nil) + opt := &pgrestore.DataSectionSettings{ + ExitOnError: true, + UseSessionReplicationRoleReplica: true, + SuperUser: s.GetSuperUser(), + } + tr := NewTableRestorer(entry, st, opt) + + conn, err := s.GetConnectionWithUser(ctx, s.nonSuperUser, s.nonSuperUserPassword) + err = tr.Execute(ctx, conn) + s.Require().NoError(err) + }) + +}