Skip to content

Commit

Permalink
Encode Kafka partition key (#553)
Browse files (browse the repository at this point in the history)
  • Loading branch information
gtopper authored Jan 9, 2025
1 parent 9b970f7 commit 77713df
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
14 changes: 9 additions & 5 deletions integration/test_kafka_integration.py
Original file line number | Diff line number | Diff line change
@@ -92,13 +92,13 @@ def test_kafka_target(kafka_topic_setup_teardown):
assert record.value.decode("UTF-8") == json.dumps(event.body, default=str)


async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown):
async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown, partition_key):
kafka_consumer = kafka_topic_setup_teardown

controller = build_flow(
[
AsyncEmitSource(),
KafkaTarget(kafka_brokers, topic, sharding_func=lambda _: 0, full_event=True),
KafkaTarget(kafka_brokers, topic, sharding_func=lambda _: partition_key, full_event=True),
]
).run()
events = []
@@ -115,7 +115,10 @@ async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardo
record = next(kafka_consumer)
if event.key is None:
if event.key is None:
assert record.key is None
if isinstance(partition_key, int):
assert record.key is None
else:
assert record.key.decode("UTF-8") == partition_key
else:
assert record.key.decode("UTF-8") == event.key
readback_records.append(json.loads(record.value.decode("UTF-8")))
@@ -143,5 +146,6 @@ async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardo
not kafka_brokers,
reason="KAFKA_BROKERS must be defined to run kafka tests",
)
def test_async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown):
asyncio.run(async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown))
@pytest.mark.parametrize("partition_key", [0, "some_string"])
def test_async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown, partition_key):
asyncio.run(async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown, partition_key))
8 changes: 5 additions & 3 deletions storey/targets.py
Original file line number | Diff line number | Diff line change
@@ -1345,9 +1345,7 @@ async def _do(self, event):
self._producer.close()
return await self._do_downstream(_termination_obj)
else:
key = None
if event.key is not None:
key = stringify_key(event.key).encode("UTF-8")
key = event.key
record = self._event_to_writer_entry(event)
if self._full_event:
record = wrap_event_for_serialization(event, record)
@@ -1359,6 +1357,10 @@ async def _do(self, event):
partition = sharding_func_result
else:
key = sharding_func_result

if key is not None:
key = stringify_key(key).encode("UTF-8")

future = self._producer.send(self._topic, record, key, partition=partition)
# Prevent garbage collection of event until persisted to kafka
future.add_callback(lambda x: event)
Expand Down

0 comments on commit 77713df

Please sign in to comment.