df_to/from_protobuf functions (#25)
* add top level conversion functions

* flatten to expand

* better docs

* rm custom word

* clear return var

* update readme

* fix note

* cleanup

* add to/from protobuf funcs

* simplify test
crflynn authored Jun 13, 2022
1 parent 9444cc1 commit 174e075
Showing 4 changed files with 247 additions and 16 deletions.
102 changes: 86 additions & 16 deletions README.md
@@ -1,6 +1,6 @@
# pbspark

This package provides a way to convert protobuf messages into pyspark dataframes and vice versa using a pyspark udf.
This package provides a way to convert protobuf messages into pyspark dataframes and vice versa using pyspark `udf`s.

## Installation

@@ -26,7 +26,68 @@ message SimpleMessage {
}
```

Using `pbspark` we can decode the messages into spark `StructType` and then flatten them.
### Basic conversion functions

There are two functions for operating on columns, `to_protobuf` and `from_protobuf`. These operations convert between a column of encoded protobuf messages and a column of structs representing the inferred message structure. `MessageConverter` instances (discussed below) can optionally be passed to these functions.

```python
from pyspark.sql.session import SparkSession
from example.example_pb2 import SimpleMessage
from pbspark import from_protobuf
from pbspark import to_protobuf

spark = SparkSession.builder.getOrCreate()

example = SimpleMessage(name="hello", quantity=5, measure=12.3)
data = [{"value": example.SerializeToString()}]
df_encoded = spark.createDataFrame(data)

df_decoded = df_encoded.select(from_protobuf(df_encoded.value, SimpleMessage).alias("value"))
df_expanded = df_decoded.select("value.*")
df_expanded.show()

# +-----+--------+-------+
# | name|quantity|measure|
# +-----+--------+-------+
# |hello| 5| 12.3|
# +-----+--------+-------+

df_reencoded = df_decoded.select(to_protobuf(df_decoded.value, SimpleMessage).alias("value"))
```
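
A single pre-configured converter (for example one with custom serializers registered, as discussed below) can be shared across calls by passing it through the `mc` keyword argument. A minimal sketch, continuing the example above:

```python
from pbspark import MessageConverter

mc = MessageConverter()  # optionally register custom serializers on this instance
df_decoded = df_encoded.select(
    from_protobuf(df_encoded.value, SimpleMessage, mc=mc).alias("value")
)
df_reencoded = df_decoded.select(
    to_protobuf(df_decoded.value, SimpleMessage, mc=mc).alias("value")
)
```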

There are two helper functions, `df_to_protobuf` and `df_from_protobuf`, for use on dataframes. They take a kwarg `expanded`, which also takes care of expanding/contracting the data between the single `value` column used in these examples and a dataframe containing one column per message field. `MessageConverter` instances (discussed below) can optionally be passed to these functions.

```python
from pyspark.sql.session import SparkSession
from example.example_pb2 import SimpleMessage
from pbspark import df_from_protobuf
from pbspark import df_to_protobuf

spark = SparkSession.builder.getOrCreate()

example = SimpleMessage(name="hello", quantity=5, measure=12.3)
data = [{"value": example.SerializeToString()}]
df_encoded = spark.createDataFrame(data)

# expanded=True will perform a `.select("value.*")` after converting,
# resulting in each protobuf field having its own column
df_expanded = df_from_protobuf(df_encoded, SimpleMessage, expanded=True)
df_expanded.show()

# +-----+--------+-------+
# | name|quantity|measure|
# +-----+--------+-------+
# |hello| 5| 12.3|
# +-----+--------+-------+

# expanded=True will first pack data using `struct([df[c] for c in df.columns])`,
# use this if the passed dataframe is already expanded
df_reencoded = df_to_protobuf(df_expanded, SimpleMessage, expanded=True)
```

### Column conversion using the `MessageConverter`

The four helper functions above are also available as methods on the `MessageConverter` class. Using an instance of `MessageConverter` we can decode the column of encoded messages into a column of spark `StructType` and then expand the fields.

```python
from pyspark.sql.session import SparkSession
@@ -37,47 +98,50 @@ spark = SparkSession.builder.getOrCreate()

example = SimpleMessage(name="hello", quantity=5, measure=12.3)
data = [{"value": example.SerializeToString()}]
df = spark.createDataFrame(data)
df_encoded = spark.createDataFrame(data)

mc = MessageConverter()
df_decoded = df.select(mc.from_protobuf(df.value, SimpleMessage).alias("value"))
df_flattened = df_decoded.select("value.*")
df_flattened.show()
df_decoded = df_encoded.select(mc.from_protobuf(df_encoded.value, SimpleMessage).alias("value"))
df_expanded = df_decoded.select("value.*")
df_expanded.show()

# +-----+--------+-------+
# | name|quantity|measure|
# +-----+--------+-------+
# |hello| 5| 12.3|
# +-----+--------+-------+

df_flattened.schema
df_expanded.schema
# StructType(List(StructField(name,StringType,true),StructField(quantity,IntegerType,true),StructField(measure,FloatType,true)))
```

We can also re-encode them into protobuf strings.
We can also re-encode them into protobuf.

```python
df_reencoded = df_decoded.select(mc.to_protobuf(df_decoded.value, SimpleMessage).alias("value"))
```

For flattened data, we can also (re-)encode after collecting and packing into a struct:
For expanded data, we can also encode after packing into a struct column:

```python
from pyspark.sql.functions import struct

df_unflattened = df_flattened.select(
    struct([df_flattened[c] for c in df_flattened.columns]).alias("value")
df_unexpanded = df_expanded.select(
    struct([df_expanded[c] for c in df_expanded.columns]).alias("value")
)
df_unflattened.show()
df_reencoded = df_unflattened.select(
    mc.to_protobuf(df_unflattened.value, SimpleMessage).alias("value")
df_reencoded = df_unexpanded.select(
    mc.to_protobuf(df_unexpanded.value, SimpleMessage).alias("value")
)
```

### Conversion details

Internally, `pbspark` uses protobuf's `MessageToDict`, which deserializes everything into JSON-compatible objects by default. The exceptions, illustrated in the sketch after this list, are:
* protobuf's bytes type, which `MessageToDict` would decode to a base64-encoded string; `pbspark` decodes any bytes fields directly to a spark `BinaryType`.
* protobuf's well-known `Timestamp` type, which `MessageToDict` would decode to a string; `pbspark` decodes any `Timestamp` messages directly to a spark `TimestampType` (via python datetime objects).
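
As a rough sketch of how these rules surface in an inferred schema, assume a hypothetical message `EventMessage` with a `bytes payload` field and a `google.protobuf.Timestamp occurred_at` field (the module `example.event_pb2` below is illustrative, not part of this repository):

```python
from pbspark import MessageConverter

from example.event_pb2 import EventMessage  # hypothetical generated module

mc = MessageConverter()
schema = mc.get_spark_schema(EventMessage)
# The bytes field is inferred as a spark BinaryType (not a base64 string) and
# the Timestamp field as a spark TimestampType, roughly:
# StructType([
#     StructField("payload", BinaryType(), True),
#     StructField("occurred_at", TimestampType(), True),
# ])
```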

### Custom conversion of message types

Custom serde is also supported. Suppose we use our `NestedMessage` from the repository's example and we want to serialize the key and value together into a single string.

```protobuf
@@ -115,9 +179,9 @@ spark = SparkSession(sc).builder.getOrCreate()

message = ExampleMessage(nested=NestedMessage(key="hello", value="world"))
data = [{"value": message.SerializeToString()}]
df = spark.createDataFrame(data)
df_encoded = spark.createDataFrame(data)

df_decoded = df.select(mc.from_protobuf(df.value, ExampleMessage).alias("value"))
df_decoded = df_encoded.select(mc.from_protobuf(df_encoded.value, ExampleMessage).alias("value"))
# rather than a struct the value of `nested` is a string
df_decoded.select("value.nested").show()

@@ -128,6 +192,8 @@ df_decoded.select("value.nested").show()
# +-----------+
```

### How to write conversion functions

More generally, custom serde functions should be written in the following format.

```python
@@ -146,3 +212,7 @@ def decode_nested(s: str, message: NestedMessage, path: str):
    message.key = key
    message.value = value
```

### Known issues

A `RecursionError` can occur when using self-referencing protobuf messages. Spark schemas do not allow for arbitrary depth, so circular or self-referencing protobuf messages will result in infinite recursion errors when inferring the schema. If you have message structures like this, you should create custom conversion functions that forcibly limit the structural depth when converting these messages, as in the sketch below.
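
As a sketch of that workaround, assume a hypothetical self-referencing message `Tree` with a `string label` field and a `Tree child` field. A custom encoder can cut the structure off at a fixed depth and emit a plain string, so schema inference never recurses; it would then be registered as a custom serializer for `Tree` as described in the sections above:

```python
import json


def encode_tree(message, depth: int = 0, max_depth: int = 3) -> str:
    """Encode a hypothetical self-referencing Tree message as a JSON string,
    dropping anything nested below max_depth."""
    payload = {"label": message.label}
    if depth < max_depth and message.HasField("child"):
        payload["child"] = json.loads(encode_tree(message.child, depth + 1, max_depth))
    return json.dumps(payload)
```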
4 changes: 4 additions & 0 deletions pbspark/__init__.py
@@ -1,2 +1,6 @@
from ._proto import MessageConverter
from ._proto import df_from_protobuf
from ._proto import df_to_protobuf
from ._proto import from_protobuf
from ._proto import to_protobuf
from ._version import __version__
108 changes: 108 additions & 0 deletions pbspark/_proto.py
@@ -10,7 +10,9 @@
from google.protobuf.message import Message
from google.protobuf.timestamp_pb2 import Timestamp
from pyspark.sql import Column
from pyspark.sql import DataFrame
from pyspark.sql.functions import col
from pyspark.sql.functions import struct
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType
from pyspark.sql.types import BinaryType
@@ -357,3 +359,109 @@ def to_protobuf(
        column = col(data) if isinstance(data, str) else data
        protobuf_encoder_udf = self.get_encoder_udf(message_type, options)
        return protobuf_encoder_udf(column)

    def df_from_protobuf(
        self,
        df: DataFrame,
        message_type: t.Type[Message],
        options: t.Optional[dict] = None,
        expanded: bool = False,
    ) -> DataFrame:
        """Decode a dataframe of encoded protobuf.
        If expanded, return a dataframe in which each field is its own column. Otherwise
        return a dataframe with a single struct column named `value`.
        """
        df_decoded = df.select(
            self.from_protobuf(df.columns[0], message_type, options).alias("value")
        )
        if expanded:
            df_decoded = df_decoded.select("value.*")
        return df_decoded

    def df_to_protobuf(
        self,
        df: DataFrame,
        message_type: t.Type[Message],
        options: t.Optional[dict] = None,
        expanded: bool = False,
    ) -> DataFrame:
        """Encode data in a dataframe to protobuf as column `value`.
        If `expanded`, the passed dataframe columns will be packed into a struct before
        converting. Otherwise it is assumed that the dataframe passed is a single column
        of data already packed into a struct.
        Returns a dataframe with a single column named `value` containing encoded data.
        """
        if expanded:
            df_struct = df.select(
                struct([df[c] for c in df.columns]).alias("value")  # type: ignore[arg-type]
            )
        else:
            df_struct = df.select(col(df.columns[0]).alias("value"))
        df_encoded = df_struct.select(
            self.to_protobuf(df_struct.value, message_type, options).alias("value")
        )
        return df_encoded


def from_protobuf(
    data: t.Union[Column, str],
    message_type: t.Type[Message],
    options: t.Optional[dict] = None,
    mc: t.Optional[MessageConverter] = None,
) -> Column:
    """Deserialize protobuf messages to spark structs."""
    mc = mc or MessageConverter()
    return mc.from_protobuf(data=data, message_type=message_type, options=options)


def to_protobuf(
    data: t.Union[Column, str],
    message_type: t.Type[Message],
    options: t.Optional[dict] = None,
    mc: t.Optional[MessageConverter] = None,
) -> Column:
    """Serialize spark structs to protobuf messages."""
    mc = mc or MessageConverter()
    return mc.to_protobuf(data=data, message_type=message_type, options=options)


def df_from_protobuf(
    df: DataFrame,
    message_type: t.Type[Message],
    options: t.Optional[dict] = None,
    expanded: bool = False,
    mc: t.Optional[MessageConverter] = None,
) -> DataFrame:
    """Decode a dataframe of encoded protobuf.
    If expanded, return a dataframe in which each field is its own column. Otherwise
    return a dataframe with a single struct column named `value`.
    """
    mc = mc or MessageConverter()
    return mc.df_from_protobuf(
        df=df, message_type=message_type, options=options, expanded=expanded
    )


def df_to_protobuf(
    df: DataFrame,
    message_type: t.Type[Message],
    options: t.Optional[dict] = None,
    expanded: bool = False,
    mc: t.Optional[MessageConverter] = None,
) -> DataFrame:
    """Encode data in a dataframe to protobuf as column `value`.
    If `expanded`, the passed dataframe columns will be packed into a struct before
    converting. Otherwise it is assumed that the dataframe passed is a single column
    of data already packed into a struct.
    Returns a dataframe with a single column named `value` containing encoded data.
    """
    mc = mc or MessageConverter()
    return mc.df_to_protobuf(
        df=df, message_type=message_type, options=options, expanded=expanded
    )
49 changes: 49 additions & 0 deletions tests/test_proto.py
@@ -30,6 +30,10 @@
from example.example_pb2 import RecursiveMessage
from pbspark._proto import MessageConverter
from pbspark._proto import _patched_convert_scalar_field_value
from pbspark._proto import df_from_protobuf
from pbspark._proto import df_to_protobuf
from pbspark._proto import from_protobuf
from pbspark._proto import to_protobuf
from tests.fixtures import decimal_serializer # type: ignore[import]
from tests.fixtures import encode_recursive

@@ -67,6 +71,11 @@ def spark():
    return spark


@pytest.fixture(params=[True, False])
def expanded(request):
    return request.param


def test_get_spark_schema():
    mc = MessageConverter()
    mc.register_serializer(
@@ -251,3 +260,43 @@ def test_recursive_message(spark):
    dfs.show(truncate=False)
    data = dfs.collect()
    assert data[0].asDict(True)["value"] == expected


def test_to_from_protobuf(example, spark, expanded):
    data = [{"value": example.SerializeToString()}]

    df = spark.createDataFrame(data)  # type: ignore[type-var]

    df_decoded = df.select(from_protobuf(df.value, ExampleMessage).alias("value"))

    mc = MessageConverter()
    assert df_decoded.schema.fields[0].dataType == mc.get_spark_schema(ExampleMessage)

    df_encoded = df_decoded.select(
        to_protobuf(df_decoded.value, ExampleMessage).alias("value")
    )

    assert df_encoded.columns == ["value"]
    assert df_encoded.schema == df.schema
    assert df.collect() == df_encoded.collect()


def test_df_to_from_protobuf(example, spark, expanded):
    data = [{"value": example.SerializeToString()}]

    df = spark.createDataFrame(data)  # type: ignore[type-var]

    df_decoded = df_from_protobuf(df, ExampleMessage, expanded=expanded)

    mc = MessageConverter()
    schema = mc.get_spark_schema(ExampleMessage)
    if expanded:
        assert df_decoded.schema == schema
    else:
        assert df_decoded.schema.fields[0].dataType == schema

    df_encoded = df_to_protobuf(df_decoded, ExampleMessage, expanded=expanded)

    assert df_encoded.columns == ["value"]
    assert df_encoded.schema == df.schema
    assert df.collect() == df_encoded.collect()
