Batch import metadata #35

Closed
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [2.0.2] - 2024-08-21
### Changed
- Use batched imports when restoring metadata

## [2.0.1] - 2024-08-21
### Changed
- Updates to the `README` and `PYPIDOC`
17 changes: 14 additions & 3 deletions assets/dags/mwaa_dr/framework/model/base_table.py
@@ -23,6 +23,7 @@

from airflow import settings
from mwaa_dr.framework.model.dependency_model import DependencyModel
import itertools

S3 = "S3"

@@ -347,19 +348,29 @@ def restore(self, **context):
"""
backup_file = self.read(context)

restore_sql = ""
if self.columns:
restore_sql = f"COPY {self.name} ({', '.join(self.columns)}) FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')"
else:
restore_sql = f"COPY {self.name} FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')"
print(f"Restore SQL: {restore_sql}")

conn = settings.engine.raw_connection()
cursor = None
try:
cursor = conn.cursor()
cursor.copy_expert(restore_sql, backup_file)
conn.commit()
insert_counter = 0
with backup_file as file:
while True:
batch = list(itertools.islice(file, self.batch_size))
if not batch:
break
cursor.copy_expert(restore_sql, StringIO("".join(batch)))
conn.commit()
insert_counter += len(batch)
print(f"Inserted {insert_counter} records")
finally:
if cursor:
cursor.close()
conn.close()
backup_file.close()

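For reference, the batching logic added above boils down to the following pattern. This is a minimal sketch assuming a psycopg2-style connection and a line-oriented backup file; the table name, columns, and batch size are illustrative values, not the framework's.

```python
# Illustrative sketch of the chunked COPY pattern introduced in restore() above.
# The table name, columns, delimiter, and batch size are assumptions for the
# example, not values taken from the framework.
import itertools
from io import StringIO


def restore_in_batches(conn, backup_file, table="task_instance",
                       columns=("dag_id", "state"), batch_size=1000):
    """Stream a pipe-delimited backup file into Postgres in fixed-size batches."""
    copy_sql = (
        f"COPY {table} ({', '.join(columns)}) "
        "FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')"
    )
    inserted = 0
    cursor = conn.cursor()
    try:
        while True:
            # islice pulls at most batch_size lines without loading the whole file
            batch = list(itertools.islice(backup_file, batch_size))
            if not batch:
                break
            cursor.copy_expert(copy_sql, StringIO("".join(batch)))
            conn.commit()  # commit per batch so memory use stays bounded per chunk
            inserted += len(batch)
    finally:
        cursor.close()
    return inserted
```

Committing after every chunk keeps each `COPY` payload small and reports progress incrementally, which is what the `insert_counter` print in the diff is doing.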
5 changes: 4 additions & 1 deletion assets/dags/mwaa_dr/v_2_4/dr_factory.py
@@ -182,6 +182,7 @@ def dag_run(self, model: DependencyModel[BaseTable]) -> BaseTable:
"start_date",
"state",
],
export_filter="dag_id != 'backup_metadata'",
export_mappings={"conf": "'\\x' || encode(conf,'hex') as conf"},
storage_type=self.storage_type,
path_prefix=self.path_prefix,
@@ -234,7 +235,7 @@ def task_instance(self, model: DependencyModel[BaseTable]) -> BaseTable:
export_mappings={
"executor_config": "'\\x' || encode(executor_config,'hex') as executor_config"
},
export_filter="state NOT IN ('running','restarting','queued','scheduled', 'up_for_retry','up_for_reschedule')",
export_filter="state NOT IN ('running','restarting','queued','scheduled', 'up_for_retry','up_for_reschedule') AND dag_id != 'backup_metadata'",
storage_type=self.storage_type,
path_prefix=self.path_prefix,
batch_size=self.batch_size,
@@ -311,6 +312,7 @@ def task_fail(self, model: DependencyModel[BaseTable]) -> BaseTable:
"start_date",
"task_id",
],
export_filter="dag_id != 'backup_metadata'",
storage_type=self.storage_type,
path_prefix=self.path_prefix,
batch_size=self.batch_size,
@@ -369,6 +371,7 @@ def job(self, model: DependencyModel[BaseTable]) -> BaseTable:
"state",
"unixname",
],
export_filter="dag_id != 'backup_metadata'",
storage_type=self.storage_type,
path_prefix=self.path_prefix,
batch_size=self.batch_size,
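The new `export_filter="dag_id != 'backup_metadata'"` arguments exclude the backup DAG's own runs, task instances, task failures, and jobs from the export. As a rough illustration, assuming the filter is appended as a `WHERE` clause to the table's export query, the effect looks like this; the helper below is hypothetical, not the framework's actual implementation.

```python
# Hypothetical helper showing how an export_filter string could be folded into a
# COPY ... TO STDOUT export query; the function name and query shape are assumptions.
def build_export_sql(table, columns, export_filter=""):
    select = f"SELECT {', '.join(columns)} FROM {table}"
    if export_filter:
        select += f" WHERE {export_filter}"
    return f"COPY ({select}) TO STDOUT WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')"


# Example: skip rows generated by the backup_metadata DAG itself
print(build_export_sql("dag_run", ["dag_id", "state"], "dag_id != 'backup_metadata'"))
```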
113 changes: 105 additions & 8 deletions tests/unit/mwaa_dr/framework/model/test_base_table.py
@@ -459,39 +459,136 @@ def test_read(mock_context):
with patch.object(table_for_local_fs, "read_from_local", return_value=buffer):
expect(table_for_local_fs.read(dict())).to.be(buffer)

def test_restore_cursor_exception(self, mock_context, mock_sql_raw_connection):
task_instance = BaseTable(
name="task_instance", model=DependencyModel(), columns=["dag_id", "state"]
)

# Create sample data
sample_data = "test|running\r\n"

with (
io.StringIO(sample_data) as store,
patch.object(task_instance, "read", return_value=store),
patch(
"mwaa_dr.framework.model.base_table.settings.engine.raw_connection"
) as mock_raw_connection,
):
mock_conn = mock_raw_connection.return_value
mock_conn.cursor.side_effect = Exception("Cursor creation failed")

with pytest.raises(Exception, match="Cursor creation failed"):
task_instance.restore(**mock_context)

# Cursor creation was attempted once; nothing was committed and the connection was closed
assert mock_conn.cursor.call_count == 1
assert mock_conn.commit.call_count == 0
assert mock_conn.close.call_count == 1

def test_restore_multi_columns(self, mock_context, mock_sql_raw_connection):
task_instance = BaseTable(
name="task_instance", model=DependencyModel(), columns=["dag_id", "state"]
)

with (
io.BytesIO(b"test,running\r\n") as store,
io.StringIO("test|running\r\n") as store,
patch.object(task_instance, "read", return_value=store),
):
task_instance.restore(**mock_context)
expect(task_instance.read.call_count).to.equal(1)
expect(task_instance.read.call_args[0][0]).to.equal(mock_context)
mock_sql_raw_connection.return_value.cursor.return_value.copy_expert.assert_called_with(
"COPY task_instance (dag_id, state) FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')",
store,
restore_sql, string_io = (
mock_sql_raw_connection.return_value.cursor.return_value.copy_expert.call_args.args
)
expected_sql = "COPY task_instance (dag_id, state) FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')"
expect(restore_sql).equal(expected_sql)
expect(string_io.getvalue()).equal("test|running\r\n")
mock_sql_raw_connection.return_value.commit.assert_called_once()
mock_sql_raw_connection.return_value.close.assert_called_once()

def test_restore_batch_size(self, mock_context, mock_sql_raw_connection):
task_instance = BaseTable(
name="task_instance", model=DependencyModel(), columns=["dag_id", "state"]
)

# Create exactly one batch worth of sample data (batch_size rows)
sample_data = (
"\r\n".join(
[f"dag_id_{i}|state_{i}" for i in range(task_instance.batch_size)]
)
+ "\r\n"
)

with (
io.StringIO(sample_data) as store,
patch.object(task_instance, "read", return_value=store),
):
task_instance.restore(**mock_context)
expect(task_instance.read.call_count).to.equal(1)
expect(task_instance.read.call_args[0][0]).to.equal(mock_context)
call_list_args = (
mock_sql_raw_connection.return_value.cursor.return_value.copy_expert.call_args_list
)
restore_sql_1, string_io_1 = call_list_args[0].args
expected_sql = "COPY task_instance (dag_id, state) FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')"
expect(restore_sql_1).equal(expected_sql)
expect(string_io_1.getvalue()).equal(sample_data)
mock_sql_raw_connection.return_value.commit.assert_called()
expect(mock_sql_raw_connection.return_value.commit.call_count).to.equal(1)
mock_sql_raw_connection.return_value.close.assert_called_once()

def test_restore_more_than_batch_size(self, mock_context, mock_sql_raw_connection):
task_instance = BaseTable(
name="task_instance", model=DependencyModel(), columns=["dag_id", "state"]
)

# Create sample data spanning two batches (batch_size + 2 rows)
first_batch = (
"\r\n".join(
[f"dag_id_{i}|state_{i}" for i in range(task_instance.batch_size)]
)
+ "\r\n"
)
second_batch = "\r\n".join([f"dag_id_{i}|state_{i}" for i in range(2)]) + "\r\n"
sample_data = first_batch + second_batch

with (
io.StringIO(sample_data) as store,
patch.object(task_instance, "read", return_value=store),
):
task_instance.restore(**mock_context)
expect(task_instance.read.call_count).to.equal(1)
expect(task_instance.read.call_args[0][0]).to.equal(mock_context)
call_list_args = (
mock_sql_raw_connection.return_value.cursor.return_value.copy_expert.call_args_list
)
restore_sql_1, string_io_1 = call_list_args[0].args
expected_sql = "COPY task_instance (dag_id, state) FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')"
expect(restore_sql_1).equal(expected_sql)
expect(string_io_1.getvalue()).equal(first_batch)
restore_sql_2, string_io_2 = call_list_args[1].args
expect(restore_sql_2).equal(expected_sql)
expect(string_io_2.getvalue()).equal(second_batch)
mock_sql_raw_connection.return_value.commit.assert_called()
expect(mock_sql_raw_connection.return_value.commit.call_count).to.equal(2)
mock_sql_raw_connection.return_value.close.assert_called_once()

def test_restore_no_columns(self, mock_context, mock_sql_raw_connection):
task_instance = BaseTable(name="task_instance", model=DependencyModel())

with (
io.BytesIO(b"test,running\r\n") as store,
io.StringIO("test,running\r\n") as store,
patch.object(task_instance, "read", return_value=store),
):
task_instance.restore(**mock_context)
expect(task_instance.read.call_count).to.equal(1)
expect(task_instance.read.call_args[0][0]).to.equal(mock_context)
mock_sql_raw_connection.return_value.cursor.return_value.copy_expert.assert_called_with(
"COPY task_instance FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')",
store,
restore_sql, string_io = (
mock_sql_raw_connection.return_value.cursor.return_value.copy_expert.call_args.args
)
expected_sql = "COPY task_instance FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')"
expect(restore_sql).equal(expected_sql)
expect(string_io.getvalue()).equal("test,running\r\n")
mock_sql_raw_connection.return_value.commit.assert_called_once()
mock_sql_raw_connection.return_value.close.assert_called_once()

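The tests above rely on `mock_context` and `mock_sql_raw_connection` fixtures defined outside this diff. A minimal sketch of what the connection fixture might look like is shown below; the patch target and fixture shape are assumptions based on how the tests use it, not the project's actual conftest.

```python
# Hypothetical fixture sketch; the real project may define this differently.
from unittest.mock import patch

import pytest


@pytest.fixture
def mock_sql_raw_connection():
    # Patch the SQLAlchemy engine's raw_connection so restore() never touches a real DB
    with patch(
        "mwaa_dr.framework.model.base_table.settings.engine.raw_connection"
    ) as raw_connection:
        yield raw_connection
```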
5 changes: 4 additions & 1 deletion tests/unit/mwaa_dr/v_2_4/test_dr_factory_2_4.py
@@ -114,6 +114,7 @@ def test_dag_run_creation(self):
"state",
],
expected_mappings={"conf": "'\\x' || encode(conf,'hex') as conf"},
expected_export_filter="dag_id != 'backup_metadata'",
)

def test_task_instance_creation(self):
@@ -154,7 +155,7 @@ def test_task_instance_creation(self):
expected_mappings={
"executor_config": "'\\x' || encode(executor_config,'hex') as executor_config"
},
expected_export_filter="state NOT IN ('running','restarting','queued','scheduled', 'up_for_retry','up_for_reschedule')",
expected_export_filter="state NOT IN ('running','restarting','queued','scheduled', 'up_for_retry','up_for_reschedule') AND dag_id != 'backup_metadata'",
)

def test_slot_pool(self):
@@ -204,6 +205,7 @@ def test_task_fail(self):
"start_date",
"task_id",
],
expected_export_filter="dag_id != 'backup_metadata'",
)

def test_job(self):
@@ -223,6 +225,7 @@ def test_job(self):
"state",
"unixname",
],
expected_export_filter="dag_id != 'backup_metadata'",
)

def test_trigger(self):