fix: various bugs encountered in production
- No longer inferring schemas if the first row_msg in a batch is a DELETE operation
  - Instead, the table schema is reflected using sqlalchemy (see the first sketch below)
- In pg9.6, pg_current_xlog_location wasn't reliable, which would cause the message consumer to hang until new data was flushed to the WAL; the max LSN is now peeked directly from the replication slot (see the second sketch below)
  - This doesn't fix dlt-hub/dlt#2229, but it was the cause (reproduced in the added test case)
- Some minor refactoring
Nicolas ESTRADA committed Feb 19, 2025
1 parent 864b746 commit a591618
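
For context on the schema fix: a minimal sketch of reflecting a table schema with SQLAlchemy, in the spirit of the new _fetch_table_schema_with_sqla helper in this commit. The connection string, schema, and table name are illustrative, not taken from the commit:

    from sqlalchemy import MetaData, Table, create_engine

    # Hypothetical DSN; substitute real credentials
    engine = create_engine("postgresql://user:secret@localhost:5432/mydb")
    try:
        # Reflect one table from the live database instead of inferring
        # its schema from a decoded replication message
        metadata = MetaData(schema="public")
        table = Table("my_table", metadata, autoload_with=engine)
        for column in table.columns:
            print(column.name, column.type, column.nullable)
    finally:
        engine.dispose()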
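
And for the LSN fix: a sketch of peeking the highest LSN from the replication slot itself instead of calling pg_current_xlog_location, assuming psycopg2 and an existing logical slot. The DSN and slot name are illustrative:

    import psycopg2

    # Hypothetical DSN and slot name
    conn = psycopg2.connect("dbname=mydb user=user password=secret host=localhost")
    try:
        with conn.cursor() as cur:
            # Peek (do not consume) the slot and take the highest LSN seen;
            # subtracting '0/0' converts the pg_lsn value to a plain integer.
            # On pg < 10 the output column is named "location" rather than "lsn".
            cur.execute(
                """
                SELECT lsn - '0/0' AS max_lsn
                FROM pg_logical_slot_peek_binary_changes(%s, NULL, NULL)
                ORDER BY lsn DESC
                LIMIT 1;
                """,
                ("my_slot",),
            )
            row = cur.fetchone()
            print(row[0] if row else None)  # None if the slot is empty
    finally:
        conn.close()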
Showing 5 changed files with 208 additions and 86 deletions.
10 changes: 5 additions & 5 deletions sources/pg_legacy_replication/__init__.py
@@ -94,7 +94,7 @@ def replication_resource(slot_name: str) -> Iterable[TDataItem]:
    advance_slot(start_lsn, slot_name, credentials)

    # continue until last message in replication slot
-    upto_lsn = get_max_lsn(credentials)
+    upto_lsn = get_max_lsn(credentials, slot_name)
    if upto_lsn is None:
        return

@@ -182,10 +182,10 @@ def init_replication(
    - When `take_snapshots` is `True`, the function configures a snapshot isolation level for consistent table snapshots.
    """
    rep_conn = get_rep_conn(credentials)
-    rep_cur = rep_conn.cursor()
-    if reset:
-        drop_replication_slot(slot_name, rep_cur)
-    slot = create_replication_slot(slot_name, rep_cur)
+    with rep_conn.cursor() as rep_cur:
+        if reset:
+            drop_replication_slot(slot_name, rep_cur)
+        slot = create_replication_slot(slot_name, rep_cur)

    # Close connection if no snapshots are needed
    if not take_snapshots:
169 changes: 117 additions & 52 deletions sources/pg_legacy_replication/helpers.py
@@ -1,6 +1,7 @@
import hashlib
from collections import defaultdict
from dataclasses import dataclass, field
+from functools import partial
from typing import (
    Any,
    Callable,
@@ -36,6 +37,7 @@
    arrow_helpers as arrow,
    engine_from_credentials,
)
+from dlt.sources.sql_database.schema_types import sqla_col_to_column_schema
from psycopg2.extensions import connection as ConnectionExt, cursor
from psycopg2.extras import (
    LogicalReplicationConnection,
@@ -153,26 +155,34 @@ def drop_replication_slot(name: str, cur: ReplicationCursor) -> None:
)


-def get_max_lsn(credentials: ConnectionStringCredentials) -> Optional[int]:
+def get_max_lsn(
+    credentials: ConnectionStringCredentials, slot_name: str
+) -> Optional[int]:
    """
    Returns maximum Log Sequence Number (LSN).
    Returns None if the replication slot is empty.
    Does not consume the slot, i.e. messages are not flushed.
    """
-    cur = _get_conn(credentials).cursor()
+    conn = _get_conn(credentials)
    try:
-        loc_fn = (
-            "pg_current_xlog_location"
-            if get_pg_version(cur) < 100000
-            else "pg_current_wal_lsn"
-        )
-        # subtract '0/0' to convert pg_lsn type to int (https://stackoverflow.com/a/73738472)
-        cur.execute(f"SELECT {loc_fn}() - '0/0' as max_lsn;")
-        lsn: int = cur.fetchone()[0]
-        return lsn
+        with conn.cursor() as cur:
+            pg_version = get_pg_version(cur)
+            lsn_field = "lsn" if pg_version >= 100000 else "location"
+            # subtract '0/0' to convert pg_lsn type to int (https://stackoverflow.com/a/73738472)
+            cur.execute(
+                f"""
+                SELECT {lsn_field} - '0/0' AS max_lsn
+                FROM pg_logical_slot_peek_binary_changes(%s, NULL, NULL)
+                ORDER BY {lsn_field} DESC
+                LIMIT 1;
+                """,
+                (slot_name,),
+            )
+            row = cur.fetchone()
+            return row[0] if row else None  # type: ignore[no-any-return]
    finally:
-        cur.connection.close()
+        conn.close()


def lsn_int_to_hex(lsn: int) -> str:
@@ -194,15 +204,16 @@ def advance_slot(
    the behavior of that method seems odd when used outside of `consume_stream`.
    """
    assert upto_lsn > 0
-    cur = _get_conn(credentials).cursor()
+    conn = _get_conn(credentials)
    try:
-        # There is unfortunately no way in pg9.6 to manually advance the replication slot
-        if get_pg_version(cur) > 100000:
-            cur.execute(
-                f"SELECT * FROM pg_replication_slot_advance('{slot_name}', '{lsn_int_to_hex(upto_lsn)}');"
-            )
+        with conn.cursor() as cur:
+            # There is unfortunately no way in pg9.6 to manually advance the replication slot
+            if get_pg_version(cur) > 100000:
+                cur.execute(
+                    f"SELECT * FROM pg_replication_slot_advance('{slot_name}', '{lsn_int_to_hex(upto_lsn)}');"
+                )
    finally:
-        cur.connection.close()
+        conn.close()


def _get_conn(
@@ -243,11 +254,13 @@ class MessageConsumer:

    def __init__(
        self,
+        credentials: ConnectionStringCredentials,
        upto_lsn: int,
        table_qnames: Set[str],
        repl_options: DefaultDict[str, ReplicationOptions],
        target_batch_size: int = 1000,
    ) -> None:
+        self.credentials = credentials
        self.upto_lsn = upto_lsn
        self.table_qnames = table_qnames
        self.target_batch_size = target_batch_size
@@ -280,15 +293,22 @@ def process_msg(self, msg: ReplicationMessage) -> None:
        row_msg = RowMessage()
        try:
            row_msg.ParseFromString(msg.payload)
+            lsn = msg.data_start
            assert row_msg.op != Op.UNKNOWN, f"Unsupported operation: {row_msg}"
+            logger.debug(
+                "op: %s, current lsn: %s, max lsn: %s",
+                Op.Name(row_msg.op),
+                lsn,
+                self.upto_lsn,
+            )

            if row_msg.op == Op.BEGIN:
                # self.last_commit_ts = _epoch_micros_to_datetime(row_msg.commit_time)
                pass
            elif row_msg.op == Op.COMMIT:
-                self.process_commit(lsn=msg.data_start)
+                self.process_commit(lsn=lsn)
            else:  # INSERT, UPDATE or DELETE
-                self.process_change(row_msg, lsn=msg.data_start)
+                self.process_change(row_msg, lsn=lsn)
        except StopReplication:
            raise
        except Exception:
@@ -317,26 +337,31 @@ def process_change(self, msg: RowMessage, lsn: int) -> None:
        if msg.table not in self.table_qnames:
            return
        table_name = msg.table.split(".")[1]
-        table_schema = self.get_table_schema(msg, table_name)
+        table_schema = self.get_table_schema(msg)
        data_item = gen_data_item(
            msg, table_schema["columns"], lsn, **self.repl_options[table_name]
        )
        self.data_items[table_name].append(data_item)

-    def get_table_schema(self, msg: RowMessage, table_name: str) -> TTableSchema:
+    def get_table_schema(self, msg: RowMessage) -> TTableSchema:
        """Given a row message, calculates or fetches a table schema."""
+        schema, table_name = msg.table.split(".")
        last_schema = self.last_table_schema.get(table_name)

-        # Used cached schema if the operation is a DELETE since the inferred one will always be less precise
-        if msg.op == Op.DELETE and last_schema:
+        # Use cached schema if the operation is a DELETE
+        if msg.op == Op.DELETE:
+            if last_schema is None:
+                # If absent, then reflect it using sqlalchemy
+                last_schema = self._fetch_table_schema_with_sqla(schema, table_name)
+                self.last_table_schema[table_name] = last_schema
            return last_schema

        # Return cached schema if hash matches
        current_hash = hash_typeinfo(msg.new_typeinfo)
        if current_hash == self.last_table_hashes.get(table_name):
            return self.last_table_schema[table_name]

-        new_schema = infer_table_schema(msg, **self.repl_options[table_name])
+        new_schema = infer_table_schema(msg, self.repl_options[table_name])
        if last_schema is None:
            # Cache the inferred schema and hash if it is not already cached
            self.last_table_schema[table_name] = new_schema
@@ -351,6 +376,33 @@ def get_table_schema(self, msg: RowMessage, table_name: str) -> TTableSchema:

        return new_schema

+    def _fetch_table_schema_with_sqla(
+        self, schema: str, table_name: str
+    ) -> TTableSchema:
+        """Last resort function used to fetch the table schema from the database"""
+        engine = engine_from_credentials(self.credentials)
+        to_col_schema = partial(
+            sqla_col_to_column_schema, reflection_level="full_with_precision"
+        )
+        try:
+            metadata = MetaData(schema=schema)
+            table = Table(table_name, metadata, autoload_with=engine)
+            options = self.repl_options[table_name]
+            included_columns = options.get("included_columns")
+            columns = {
+                col["name"]: col
+                for c in table.columns
+                if (col := to_col_schema(c)) is not None
+                and (not included_columns or c.name in included_columns)
+            }
+
+            return TTableSchema(
+                name=table_name,
+                columns=add_replication_columns(columns, **options),
+            )
+        finally:
+            engine.dispose()


def hash_typeinfo(new_typeinfo: Sequence[TypeInfo]) -> int:
"""Generate a hash for the entire new_typeinfo list by hashing each TypeInfo message."""
@@ -386,20 +438,30 @@ def __iter__(self) -> Iterator[TableItems]:
        Maintains LSN of last consumed commit message in object state.
        Advances the slot only when all messages have been consumed.
        """
-        cur = get_rep_conn(self.credentials).cursor()
+        conn = get_rep_conn(self.credentials)
        consumer = MessageConsumer(
+            credentials=self.credentials,
            upto_lsn=self.upto_lsn,
            table_qnames=self.table_qnames,
            repl_options=self.repl_options,
            target_batch_size=self.target_batch_size,
        )

+        cur = conn.cursor()
        try:
            cur.start_replication(slot_name=self.slot_name, start_lsn=self.start_lsn)
            cur.consume_stream(consumer)
        except StopReplication:  # completed batch or reached `upto_lsn`
            yield from self.flush_batch(cur, consumer)
        finally:
-            cur.connection.close()
+            logger.debug(
+                "Closing connection... last_commit_lsn: %s, generated_all: %s, feedback_ts: %s",
+                self.last_commit_lsn,
+                self.generated_all,
+                cur.feedback_timestamp,
+            )
+            cur.close()
+            conn.close()

    def flush_batch(
        self, cur: ReplicationCursor, consumer: MessageConsumer
@@ -489,65 +551,68 @@ def emit_arrow_table(
)


-def infer_table_schema(
-    msg: RowMessage,
-    include_lsn: bool = True,
-    include_deleted_ts: bool = True,
-    include_commit_ts: bool = False,
-    include_tx_id: bool = False,
-    included_columns: Optional[Set[str]] = None,
-    **_: Any,
-) -> TTableSchema:
+def infer_table_schema(msg: RowMessage, options: ReplicationOptions) -> TTableSchema:
    """Infers the table schema from the replication message and optional hints."""
-    # Choose the correct source based on operation type
-    is_change = msg.op != Op.DELETE
-    tuples = msg.new_tuple if is_change else msg.old_tuple
-    schema = TTableSchema(name=msg.table.split(".")[1])
-
-    # Filter and map columns, conditionally using `new_typeinfo` when available
-    schema["columns"] = {
+    assert msg.op != Op.DELETE
+    included_columns = options.get("included_columns")
+    columns = {
        col_name: _to_dlt_column_schema(
-            col_name, datum=col, type_info=msg.new_typeinfo[i] if is_change else None
+            col_name, datum=col, type_info=msg.new_typeinfo[i]
        )
-        for i, col in enumerate(tuples)
+        for i, col in enumerate(msg.new_tuple)
        if (col_name := _actual_column_name(col))
        and (not included_columns or col_name in included_columns)
    }

-    # Add replication columns
+    return TTableSchema(
+        name=msg.table.split(".")[1],
+        columns=add_replication_columns(columns, **options),
+    )
+
+
+def add_replication_columns(
+    columns: TTableSchemaColumns,
+    *,
+    include_lsn: bool = True,
+    include_deleted_ts: bool = True,
+    include_commit_ts: bool = False,
+    include_tx_id: bool = False,
+    **_: Any,
+) -> TTableSchemaColumns:
    if include_lsn:
-        schema["columns"]["_pg_lsn"] = {
+        columns["_pg_lsn"] = {
            "data_type": "bigint",
            "name": "_pg_lsn",
            "nullable": True,
        }
    if include_deleted_ts:
-        schema["columns"]["_pg_deleted_ts"] = {
+        columns["_pg_deleted_ts"] = {
            "data_type": "timestamp",
            "name": "_pg_deleted_ts",
            "nullable": True,
        }
    if include_commit_ts:
-        schema["columns"]["_pg_commit_ts"] = {
+        columns["_pg_commit_ts"] = {
            "data_type": "timestamp",
            "name": "_pg_commit_ts",
            "nullable": True,
        }
    if include_tx_id:
-        schema["columns"]["_pg_tx_id"] = {
+        columns["_pg_tx_id"] = {
            "data_type": "bigint",
            "name": "_pg_tx_id",
            "nullable": True,
            "precision": 32,
        }

-    return schema
+    return columns


def gen_data_item(
    msg: RowMessage,
    column_schema: TTableSchemaColumns,
    lsn: int,
    *,
    include_lsn: bool = True,
    include_deleted_ts: bool = True,
    include_commit_ts: bool = False,