Batch import metadata #35

Closed
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [2.0.2] - 2024-08-21
### Changed
- Use batch import when restoring metadata tables

## [2.0.1] - 2024-08-21
### Changed
- Updates to the `README` and `PYPIDOC`
21 changes: 18 additions & 3 deletions assets/dags/mwaa_dr/framework/model/base_table.py
@@ -347,19 +347,34 @@ def restore(self, **context):
"""
backup_file = self.read(context)

restore_sql = ""
if self.columns:
restore_sql = f"COPY {self.name} ({', '.join(self.columns)}) FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')"
else:
restore_sql = f"COPY {self.name} FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')"
print(f"Restore SQL: {restore_sql}")

conn = settings.engine.raw_connection()
cursor = None
try:
cursor = conn.cursor()
cursor.copy_expert(restore_sql, backup_file)
conn.commit()
insert_counter = 0
with backup_file as file:
batch = []
for line in file:
batch.append(line)
if len(batch) == self.batch_size:
cursor.copy_expert(restore_sql, StringIO("".join(batch)))
conn.commit()
insert_counter += self.batch_size
batch = []
if batch:
insert_counter += len(batch)
cursor.copy_expert(restore_sql, StringIO("".join(batch)))
conn.commit()
print(f"Inserted {insert_counter} records")
finally:
if cursor:
cursor.close()
conn.close()
backup_file.close()
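
Reviewer note: the hunk above interleaves the old single-shot `copy_expert` call with the new batched path, so here is a self-contained sketch of the batching idea for quick reference. It is a minimal illustration only — the helper name, the standalone signature, and the `batch_size` default are assumptions, not part of the PR, which keeps this logic inside `BaseTable.restore`.

```python
from io import StringIO


def batched_copy(conn, table, columns, backup_file, batch_size=1000):
    """Stream a '|'-delimited CSV backup into `table`, committing every `batch_size` rows.

    Assumes `conn` is a psycopg2-style connection (as returned by
    settings.engine.raw_connection()) and `backup_file` yields CSV lines.
    """
    restore_sql = (
        f"COPY {table} ({', '.join(columns)}) "
        "FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')"
    )
    cursor = conn.cursor()
    inserted = 0
    try:
        batch = []
        for line in backup_file:
            batch.append(line)
            if len(batch) == batch_size:
                # Stream one full batch through COPY and commit it.
                cursor.copy_expert(restore_sql, StringIO("".join(batch)))
                conn.commit()
                inserted += len(batch)
                batch = []
        if batch:
            # Flush the final partial batch.
            cursor.copy_expert(restore_sql, StringIO("".join(batch)))
            conn.commit()
            inserted += len(batch)
    finally:
        cursor.close()
    return inserted
```

Committing per batch keeps each transaction small, so a failure midway loses at most one batch of work instead of the whole restore.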

5 changes: 4 additions & 1 deletion assets/dags/mwaa_dr/v_2_4/dr_factory.py
@@ -182,6 +182,7 @@ def dag_run(self, model: DependencyModel[BaseTable]) -> BaseTable:
"start_date",
"state",
],
export_filter="dag_id != 'backup_metadata'",
export_mappings={"conf": "'\\x' || encode(conf,'hex') as conf"},
storage_type=self.storage_type,
path_prefix=self.path_prefix,
@@ -234,7 +235,7 @@ def task_instance(self, model: DependencyModel[BaseTable]) -> BaseTable:
export_mappings={
"executor_config": "'\\x' || encode(executor_config,'hex') as executor_config"
},
export_filter="state NOT IN ('running','restarting','queued','scheduled', 'up_for_retry','up_for_reschedule')",
export_filter="state NOT IN ('running','restarting','queued','scheduled', 'up_for_retry','up_for_reschedule') AND dag_id != 'backup_metadata'",
storage_type=self.storage_type,
path_prefix=self.path_prefix,
batch_size=self.batch_size,
@@ -311,6 +312,7 @@ def task_fail(self, model: DependencyModel[BaseTable]) -> BaseTable:
"start_date",
"task_id",
],
export_filter="dag_id != 'backup_metadata'",
storage_type=self.storage_type,
path_prefix=self.path_prefix,
batch_size=self.batch_size,
@@ -369,6 +371,7 @@ def job(self, model: DependencyModel[BaseTable]) -> BaseTable:
"state",
"unixname",
],
export_filter="dag_id != 'backup_metadata'",
storage_type=self.storage_type,
path_prefix=self.path_prefix,
batch_size=self.batch_size,
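The new `export_filter="dag_id != 'backup_metadata'"` arguments presumably keep the backup DAG's own runs, task instances, task failures, and jobs out of the export so they are not re-imported on restore. As a purely hypothetical illustration (the real query construction lives in the framework's export path, not in this hunk), the filter could be folded into an export query roughly as follows:

```python
# Hypothetical only: how an export_filter string might be combined with a column
# list to form the export query. Table and column names here are illustrative.
columns = ["conf", "dag_id", "start_date", "state"]
export_filter = "dag_id != 'backup_metadata'"

export_sql = f"SELECT {', '.join(columns)} FROM dag_run WHERE {export_filter}"
print(export_sql)
# SELECT conf, dag_id, start_date, state FROM dag_run WHERE dag_id != 'backup_metadata'
```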
56 changes: 48 additions & 8 deletions tests/unit/mwaa_dr/framework/model/test_base_table.py
@@ -465,33 +465,73 @@ def test_restore_multi_columns(self, mock_context, mock_sql_raw_connection):
)

with (
io.BytesIO(b"test,running\r\n") as store,
io.StringIO("test|running\r\n") as store,
patch.object(task_instance, "read", return_value=store),
):
task_instance.restore(**mock_context)
expect(task_instance.read.call_count).to.equal(1)
expect(task_instance.read.call_args[0][0]).to.equal(mock_context)
mock_sql_raw_connection.return_value.cursor.return_value.copy_expert.assert_called_with(
"COPY task_instance (dag_id, state) FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')",
store,
restore_sql, string_io = (
mock_sql_raw_connection.return_value.cursor.return_value.copy_expert.call_args.args
)
expected_sql = "COPY task_instance (dag_id, state) FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')"
expect(restore_sql).equal(expected_sql)
expect(string_io.getvalue()).equal("test|running\r\n")
mock_sql_raw_connection.return_value.commit.assert_called_once()
mock_sql_raw_connection.return_value.close.assert_called_once()

def test_restore_batch_size(self, mock_context, mock_sql_raw_connection):
task_instance = BaseTable(
name="task_instance", model=DependencyModel(), columns=["dag_id", "state"]
)

# Create sample data spanning one full batch plus a two-row remainder
first_batch = (
"\r\n".join(
[f"dag_id_{i}|state_{i}" for i in range(task_instance.batch_size)]
)
+ "\r\n"
)
second_batch = "\r\n".join([f"dag_id_{i}|state_{i}" for i in range(2)]) + "\r\n"
sample_data = first_batch + second_batch

with (
io.StringIO(sample_data) as store,
patch.object(task_instance, "read", return_value=store),
):
task_instance.restore(**mock_context)
expect(task_instance.read.call_count).to.equal(1)
expect(task_instance.read.call_args[0][0]).to.equal(mock_context)
call_list_args = (
mock_sql_raw_connection.return_value.cursor.return_value.copy_expert.call_args_list
)
restore_sql_1, string_io_1 = call_list_args[0].args
expected_sql = "COPY task_instance (dag_id, state) FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')"
expect(restore_sql_1).equal(expected_sql)
expect(string_io_1.getvalue()).equal(first_batch)
restore_sql_2, string_io_2 = call_list_args[1].args
expect(restore_sql_2).equal(expected_sql)
expect(string_io_2.getvalue()).equal(second_batch)
mock_sql_raw_connection.return_value.commit.assert_called()
expect(mock_sql_raw_connection.return_value.commit.call_count).to.equal(2)
mock_sql_raw_connection.return_value.close.assert_called_once()

def test_restore_no_columns(self, mock_context, mock_sql_raw_connection):
task_instance = BaseTable(name="task_instance", model=DependencyModel())

with (
io.BytesIO(b"test,running\r\n") as store,
io.StringIO("test,running\r\n") as store,
patch.object(task_instance, "read", return_value=store),
):
task_instance.restore(**mock_context)
expect(task_instance.read.call_count).to.equal(1)
expect(task_instance.read.call_args[0][0]).to.equal(mock_context)
mock_sql_raw_connection.return_value.cursor.return_value.copy_expert.assert_called_with(
"COPY task_instance FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')",
store,
restore_sql, string_io = (
mock_sql_raw_connection.return_value.cursor.return_value.copy_expert.call_args.args
)
expected_sql = "COPY task_instance FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER '|')"
expect(restore_sql).equal(expected_sql)
expect(string_io.getvalue()).equal("test,running\r\n")
mock_sql_raw_connection.return_value.commit.assert_called_once()
mock_sql_raw_connection.return_value.close.assert_called_once()

5 changes: 4 additions & 1 deletion tests/unit/mwaa_dr/v_2_4/test_dr_factory_2_4.py
@@ -114,6 +114,7 @@ def test_dag_run_creation(self):
"state",
],
expected_mappings={"conf": "'\\x' || encode(conf,'hex') as conf"},
expected_export_filter="dag_id != 'backup_metadata'",
)

def test_task_instance_creation(self):
@@ -154,7 +155,7 @@ def test_task_instance_creation(self):
expected_mappings={
"executor_config": "'\\x' || encode(executor_config,'hex') as executor_config"
},
expected_export_filter="state NOT IN ('running','restarting','queued','scheduled', 'up_for_retry','up_for_reschedule')",
expected_export_filter="state NOT IN ('running','restarting','queued','scheduled', 'up_for_retry','up_for_reschedule') AND dag_id != 'backup_metadata'",
)

def test_slot_pool(self):
@@ -204,6 +205,7 @@ def test_task_fail(self):
"start_date",
"task_id",
],
expected_export_filter="dag_id != 'backup_metadata'",
)

def test_job(self):
@@ -223,6 +225,7 @@ def test_job(self):
"state",
"unixname",
],
expected_export_filter="dag_id != 'backup_metadata'",
)

def test_trigger(self):