From 174e0751e094a60cf090cd882268c00b9405a492 Mon Sep 17 00:00:00 2001
From: Flynn
Date: Mon, 13 Jun 2022 00:20:22 -0400
Subject: [PATCH] df_to/from_protobuf functions (#25)

* add top level conversion functions
* flatten to expand
* better docs
* rm custom word
* clear return var
* update readme
* fix note
* cleanup
* add to/from protobuf funcs
* simplify test
---
 README.md           | 102 ++++++++++++++++++++++++++++++++++-------
 pbspark/__init__.py |   4 ++
 pbspark/_proto.py   | 108 ++++++++++++++++++++++++++++++++++++++++++++
 tests/test_proto.py |  49 ++++++++++++++++++++
 4 files changed, 247 insertions(+), 16 deletions(-)

diff --git a/README.md b/README.md
index e134420..65644e4 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # pbspark
 
-This package provides a way to convert protobuf messages into pyspark dataframes and vice versa using a pyspark udf.
+This package provides a way to convert protobuf messages into pyspark dataframes and vice versa using pyspark `udf`s.
 
 ## Installation
 
@@ -26,7 +26,68 @@
 }
 ```
 
-Using `pbspark` we can decode the messages into spark `StructType` and then flatten them.
+### Basic conversion functions
+
+There are two functions for operating on columns, `to_protobuf` and `from_protobuf`. These operations convert between a column of encoded protobuf messages and a column of structs representing the inferred message structure. `MessageConverter` instances (discussed below) can optionally be passed to these functions.
+
+```python
+from pyspark.sql.session import SparkSession
+from example.example_pb2 import SimpleMessage
+from pbspark import from_protobuf
+from pbspark import to_protobuf
+
+spark = SparkSession.builder.getOrCreate()
+
+example = SimpleMessage(name="hello", quantity=5, measure=12.3)
+data = [{"value": example.SerializeToString()}]
+df_encoded = spark.createDataFrame(data)
+
+df_decoded = df_encoded.select(from_protobuf(df_encoded.value, SimpleMessage).alias("value"))
+df_expanded = df_decoded.select("value.*")
+df_expanded.show()
+
+# +-----+--------+-------+
+# | name|quantity|measure|
+# +-----+--------+-------+
+# |hello|       5|   12.3|
+# +-----+--------+-------+
+
+df_reencoded = df_decoded.select(to_protobuf(df_decoded.value, SimpleMessage).alias("value"))
+```
+
+There are two helper functions, `df_to_protobuf` and `df_from_protobuf`, for use on dataframes. They take a kwarg `expanded`, which also takes care of expanding or contracting the data between the single `value` column used in these examples and a dataframe with one column per message field. `MessageConverter` instances (discussed below) can optionally be passed to these functions.
+ +```python +from pyspark.sql.session import SparkSession +from example.example_pb2 import SimpleMessage +from pbspark import df_from_protobuf +from pbspark import df_to_protobuf + +spark = SparkSession.builder.getOrCreate() + +example = SimpleMessage(name="hello", quantity=5, measure=12.3) +data = [{"value": example.SerializeToString()}] +df_encoded = spark.createDataFrame(data) + +# expanded=True will perform a `.select("value.*")` after converting, +# resulting in each protobuf field having its own column +df_expanded = df_from_protobuf(df_encoded, SimpleMessage, expanded=True) +df_expanded.show() + +# +-----+--------+-------+ +# | name|quantity|measure| +# +-----+--------+-------+ +# |hello| 5| 12.3| +# +-----+--------+-------+ + +# expanded=True will first pack data using `struct([df[c] for c in df.columns])`, +# use this if the passed dataframe is already expanded +df_reencoded = df_to_protobuf(df_expanded, SimpleMessage, expanded=True) +``` + +### Column conversion using the `MessageConverter` + +The four helper functions above are also available as methods on the `MessageConverter` class. Using an instance of `MessageConverter` we can decode the column of encoded messages into a column of spark `StructType` and then expand the fields. ```python from pyspark.sql.session import SparkSession @@ -37,12 +98,12 @@ spark = SparkSession.builder.getOrCreate() example = SimpleMessage(name="hello", quantity=5, measure=12.3) data = [{"value": example.SerializeToString()}] -df = spark.createDataFrame(data) +df_encoded = spark.createDataFrame(data) mc = MessageConverter() -df_decoded = df.select(mc.from_protobuf(df.value, SimpleMessage).alias("value")) -df_flattened = df_decoded.select("value.*") -df_flattened.show() +df_decoded = df_encoded.select(mc.from_protobuf(df_encoded.value, SimpleMessage).alias("value")) +df_expanded = df_decoded.select("value.*") +df_expanded.show() # +-----+--------+-------+ # | name|quantity|measure| @@ -50,34 +111,37 @@ df_flattened.show() # |hello| 5| 12.3| # +-----+--------+-------+ -df_flattened.schema +df_expanded.schema # StructType(List(StructField(name,StringType,true),StructField(quantity,IntegerType,true),StructField(measure,FloatType,true)) ``` -We can also re-encode them into protobuf strings. +We can also re-encode them into protobuf. ```python df_reencoded = df_decoded.select(mc.to_protobuf(df_decoded.value, SimpleMessage).alias("value")) ``` -For flattened data, we can also (re-)encode after collecting and packing into a struct: +For expanded data, we can also encode after packing into a struct column: ```python from pyspark.sql.functions import struct -df_unflattened = df_flattened.select( - struct([df_flattened[c] for c in df_flattened.columns]).alias("value") +df_unexpanded = df_expanded.select( + struct([df_expanded[c] for c in df_expanded.columns]).alias("value") ) -df_unflattened.show() -df_reencoded = df_unflattened.select( - mc.to_protobuf(df_unflattened.value, SimpleMessage).alias("value") +df_reencoded = df_unexpanded.select( + mc.to_protobuf(df_unexpanded.value, SimpleMessage).alias("value") ) ``` +### Conversion details + Internally, `pbspark` uses protobuf's `MessageToDict`, which deserializes everything into JSON compatible objects by default. The exceptions are * protobuf's bytes type, which `MessageToDict` would decode to a base64-encoded string; `pbspark` will decode any bytes fields directly to a spark `BinaryType`. 
* protobuf's well known type, Timestamp type, which `MessageToDict` would decode to a string; `pbspark` will decode any Timestamp messages directly to a spark `TimestampType` (via python datetime objects). +### Custom conversion of message types + Custom serde is also supported. Suppose we use our `NestedMessage` from the repository's example and we want to serialize the key and value together into a single string. ```protobuf @@ -115,9 +179,9 @@ spark = SparkSession(sc).builder.getOrCreate() message = ExampleMessage(nested=NestedMessage(key="hello", value="world")) data = [{"value": message.SerializeToString()}] -df = spark.createDataFrame(data) +df_encoded = spark.createDataFrame(data) -df_decoded = df.select(mc.from_protobuf(df.value, ExampleMessage).alias("value")) +df_decoded = df_encoded.select(mc.from_protobuf(df_encoded.value, ExampleMessage).alias("value")) # rather than a struct the value of `nested` is a string df_decoded.select("value.nested").show() @@ -128,6 +192,8 @@ df_decoded.select("value.nested").show() # +-----------+ ``` +### How to write conversion functions + More generally, custom serde functions should be written in the following format. ```python @@ -146,3 +212,7 @@ def decode_nested(s: str, message: NestedMessage, path: str): message.key = key message.value = value ``` + +### Known issues + +`RecursionError` when using self-referencing protobuf messages. Spark schemas do not allow for arbitrary depth, so protobuf messages which are circular- or self-referencing will result in infinite recursion errors when inferring the schema. If you have message structures like this you should resort to creating custom conversion functions, which forcibly limit the structural depth when converting these messages. \ No newline at end of file diff --git a/pbspark/__init__.py b/pbspark/__init__.py index a285cc3..904b734 100644 --- a/pbspark/__init__.py +++ b/pbspark/__init__.py @@ -1,2 +1,6 @@ from ._proto import MessageConverter +from ._proto import df_from_protobuf +from ._proto import df_to_protobuf +from ._proto import from_protobuf +from ._proto import to_protobuf from ._version import __version__ diff --git a/pbspark/_proto.py b/pbspark/_proto.py index 82dc94d..6e645a2 100644 --- a/pbspark/_proto.py +++ b/pbspark/_proto.py @@ -10,7 +10,9 @@ from google.protobuf.message import Message from google.protobuf.timestamp_pb2 import Timestamp from pyspark.sql import Column +from pyspark.sql import DataFrame from pyspark.sql.functions import col +from pyspark.sql.functions import struct from pyspark.sql.functions import udf from pyspark.sql.types import ArrayType from pyspark.sql.types import BinaryType @@ -357,3 +359,109 @@ def to_protobuf( column = col(data) if isinstance(data, str) else data protobuf_encoder_udf = self.get_encoder_udf(message_type, options) return protobuf_encoder_udf(column) + + def df_from_protobuf( + self, + df: DataFrame, + message_type: t.Type[Message], + options: t.Optional[dict] = None, + expanded: bool = False, + ) -> DataFrame: + """Decode a dataframe of encoded protobuf. + + If expanded, return a dataframe in which each field is its own column. Otherwise + return a dataframe with a single struct column named `value`. 
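+
+        The first column of the passed dataframe is assumed to contain the
+        serialized protobuf messages.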
+ """ + df_decoded = df.select( + self.from_protobuf(df.columns[0], message_type, options).alias("value") + ) + if expanded: + df_decoded = df_decoded.select("value.*") + return df_decoded + + def df_to_protobuf( + self, + df: DataFrame, + message_type: t.Type[Message], + options: t.Optional[dict] = None, + expanded: bool = False, + ) -> DataFrame: + """Encode data in a dataframe to protobuf as column `value`. + + If `expanded`, the passed dataframe columns will be packed into a struct before + converting. Otherwise it is assumed that the dataframe passed is a single column + of data already packed into a struct. + + Returns a dataframe with a single column named `value` containing encoded data. + """ + if expanded: + df_struct = df.select( + struct([df[c] for c in df.columns]).alias("value") # type: ignore[arg-type] + ) + else: + df_struct = df.select(col(df.columns[0]).alias("value")) + df_encoded = df_struct.select( + self.to_protobuf(df_struct.value, message_type, options).alias("value") + ) + return df_encoded + + +def from_protobuf( + data: t.Union[Column, str], + message_type: t.Type[Message], + options: t.Optional[dict] = None, + mc: MessageConverter = None, +) -> Column: + """Deserialize protobuf messages to spark structs""" + mc = mc or MessageConverter() + return mc.from_protobuf(data=data, message_type=message_type, options=options) + + +def to_protobuf( + data: t.Union[Column, str], + message_type: t.Type[Message], + options: t.Optional[dict] = None, + mc: MessageConverter = None, +) -> Column: + """Serialize spark structs to protobuf messages.""" + mc = mc or MessageConverter() + return mc.to_protobuf(data=data, message_type=message_type, options=options) + + +def df_from_protobuf( + df: DataFrame, + message_type: t.Type[Message], + options: t.Optional[dict] = None, + expanded: bool = False, + mc: MessageConverter = None, +) -> DataFrame: + """Decode a dataframe of encoded protobuf. + + If expanded, return a dataframe in which each field is its own column. Otherwise + return a dataframe with a single struct column named `value`. + """ + mc = mc or MessageConverter() + return mc.df_from_protobuf( + df=df, message_type=message_type, options=options, expanded=expanded + ) + + +def df_to_protobuf( + df: DataFrame, + message_type: t.Type[Message], + options: t.Optional[dict] = None, + expanded: bool = False, + mc: MessageConverter = None, +) -> DataFrame: + """Encode data in a dataframe to protobuf as column `value`. + + If `expanded`, the passed dataframe columns will be packed into a struct before + converting. Otherwise it is assumed that the dataframe passed is a single column + of data already packed into a struct. + + Returns a dataframe with a single column named `value` containing encoded data. 
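+
+    If `mc` is not provided, a new `MessageConverter` instance is created and
+    used for the conversion.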
+ """ + mc = mc or MessageConverter() + return mc.df_to_protobuf( + df=df, message_type=message_type, options=options, expanded=expanded + ) diff --git a/tests/test_proto.py b/tests/test_proto.py index 28f626c..0bc152e 100644 --- a/tests/test_proto.py +++ b/tests/test_proto.py @@ -30,6 +30,10 @@ from example.example_pb2 import RecursiveMessage from pbspark._proto import MessageConverter from pbspark._proto import _patched_convert_scalar_field_value +from pbspark._proto import df_from_protobuf +from pbspark._proto import df_to_protobuf +from pbspark._proto import from_protobuf +from pbspark._proto import to_protobuf from tests.fixtures import decimal_serializer # type: ignore[import] from tests.fixtures import encode_recursive @@ -67,6 +71,11 @@ def spark(): return spark +@pytest.fixture(params=[True, False]) +def expanded(request): + return request.param + + def test_get_spark_schema(): mc = MessageConverter() mc.register_serializer( @@ -251,3 +260,43 @@ def test_recursive_message(spark): dfs.show(truncate=False) data = dfs.collect() assert data[0].asDict(True)["value"] == expected + + +def test_to_from_protobuf(example, spark, expanded): + data = [{"value": example.SerializeToString()}] + + df = spark.createDataFrame(data) # type: ignore[type-var] + + df_decoded = df.select(from_protobuf(df.value, ExampleMessage).alias("value")) + + mc = MessageConverter() + assert df_decoded.schema.fields[0].dataType == mc.get_spark_schema(ExampleMessage) + + df_encoded = df_decoded.select( + to_protobuf(df_decoded.value, ExampleMessage).alias("value") + ) + + assert df_encoded.columns == ["value"] + assert df_encoded.schema == df.schema + assert df.collect() == df_encoded.collect() + + +def test_df_to_from_protobuf(example, spark, expanded): + data = [{"value": example.SerializeToString()}] + + df = spark.createDataFrame(data) # type: ignore[type-var] + + df_decoded = df_from_protobuf(df, ExampleMessage, expanded=expanded) + + mc = MessageConverter() + schema = mc.get_spark_schema(ExampleMessage) + if expanded: + assert df_decoded.schema == schema + else: + assert df_decoded.schema.fields[0].dataType == schema + + df_encoded = df_to_protobuf(df_decoded, ExampleMessage, expanded=expanded) + + assert df_encoded.columns == ["value"] + assert df_encoded.schema == df.schema + assert df.collect() == df_encoded.collect()