Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure any fields that could intersect with Message are renamed #321

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 26 additions & 36 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,12 @@ class Message(ABC):
Calls :meth:`__bool__`.
"""

_serialized_on_wire: bool
serialized_on_wire: bool
"""
If this message was or should be serialized on the wire. This can be used to detect
presence (e.g. optional wrapper message) and is used internally during
parsing/serialization.
"""
_unknown_fields: bytes
_group_current: Dict[str, str]

Expand All @@ -634,7 +639,7 @@ def __post_init__(self) -> None:
group_current[meta.group] = field_name

# Now that all the defaults are set, reset it!
self.__dict__["_serialized_on_wire"] = not all_sentinel
self.__dict__["serialized_on_wire"] = not all_sentinel
self.__dict__["_unknown_fields"] = b""
self.__dict__["_group_current"] = group_current

Expand Down Expand Up @@ -694,9 +699,9 @@ def __getattribute__(self, name: str) -> Any:
return value

def __setattr__(self, attr: str, value: Any) -> None:
if attr != "_serialized_on_wire":
if attr != "serialized_on_wire":
# Track when a field has been set.
self.__dict__["_serialized_on_wire"] = True
self.__dict__["serialized_on_wire"] = True

if hasattr(self, "_group_current"): # __post_init__ had already run
if attr in self._betterproto.oneof_group_by_field:
Expand Down Expand Up @@ -756,7 +761,7 @@ def __bytes__(self) -> bytes:

# Empty messages can still be sent on the wire if they were
# set (or received empty).
serialize_empty = isinstance(value, Message) and value._serialized_on_wire
serialize_empty = isinstance(value, Message) and value.serialized_on_wire

include_default_value_for_oneof = self._include_default_value_for_oneof(
field_name=field_name, meta=meta
Expand Down Expand Up @@ -924,7 +929,7 @@ def _postprocess_single(
value = _get_wrapper(meta.wraps)().parse(value).value
else:
value = cls().parse(value)
value._serialized_on_wire = True
value.serialized_on_wire = True
elif meta.proto_type == TYPE_MAP:
value = self._betterproto.cls_by_field[field_name]().parse(value)

Expand Down Expand Up @@ -953,7 +958,7 @@ def parse(self: T, data: bytes) -> T:
The initialized message.
"""
# Got some data over the wire
self._serialized_on_wire = True
self.serialized_on_wire = True
proto_meta = self._betterproto
for parsed in parse_fields(data):
field_name = proto_meta.field_name_by_number.get(parsed.number)
Expand Down Expand Up @@ -1089,7 +1094,7 @@ def to_dict(
if include_default_values:
output[cased_name] = value
elif (
value._serialized_on_wire
value.serialized_on_wire
or include_default_values
or self._include_default_value_for_oneof(
field_name=field_name, meta=meta
Expand Down Expand Up @@ -1170,7 +1175,7 @@ def from_dict(self: T, value: Dict[str, Any]) -> T:
:class:`Message`
The initialized message.
"""
self._serialized_on_wire = True
self.serialized_on_wire = True
for key in value:
field_name = safe_snake_case(key)
meta = self._betterproto.meta_by_field_name.get(field_name)
Expand Down Expand Up @@ -1280,34 +1285,19 @@ def from_json(self: T, value: Union[str, bytes]) -> T:
"""
return self.from_dict(json.loads(value))

def which_one_of(self, group_name: str) -> Tuple[str, Optional[Any]]:
"""
Return the name and value of a message's one-of field group.

def serialized_on_wire(message: Message) -> bool:
"""
If this message was or should be serialized on the wire. This can be used to detect
presence (e.g. optional wrapper message) and is used internally during
parsing/serialization.

Returns
--------
:class:`bool`
Whether this message was or should be serialized on the wire.
"""
return message._serialized_on_wire


def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]:
"""
Return the name and value of a message's one-of field group.

Returns
--------
Tuple[:class:`str`, Any]
The field name and the value for that field.
"""
field_name = message._group_current.get(group_name)
if not field_name:
return "", None
return field_name, getattr(message, field_name)
Returns
--------
Tuple[:class:`str`, Any]
The field name and the value for that field.
"""
field_name = self._group_current.get(group_name)
if not field_name:
return "", None
return field_name, getattr(self, field_name)


# Circular import workaround: google.protobuf depends on base classes defined above.
Expand Down
12 changes: 9 additions & 3 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

import builtins
import betterproto
from betterproto import which_one_of
from betterproto.casing import sanitize_name
from betterproto.compile.importing import (
get_type_reference,
Expand Down Expand Up @@ -114,6 +113,9 @@
FieldDescriptorProtoType.TYPE_SINT32, # 17
FieldDescriptorProtoType.TYPE_SINT64, # 18
)
UNSAFE_FIELD_NAMES = frozenset(dir(betterproto.Message)) | frozenset(
betterproto.Message.__annotations__
)


def monkey_patch_oneof_index():
Expand Down Expand Up @@ -355,7 +357,7 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
us to tell whether it was set, via the which_one_of interface.
"""

return which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
return proto_field_obj.which_one_of("oneof_index")[0] == "oneof_index"


@dataclass
Expand Down Expand Up @@ -501,7 +503,11 @@ def packed(self) -> bool:
@property
def py_name(self) -> str:
"""Pythonized name."""
return pythonize_field_name(self.proto_name)
unsafe_name = pythonize_field_name(self.proto_name)
# rename fields in case they clash with things defined in Message
if unsafe_name in UNSAFE_FIELD_NAMES:
return f"{unsafe_name}_"
return unsafe_name

@property
def proto_name(self) -> str:
Expand Down
5 changes: 2 additions & 3 deletions tests/inputs/oneof/test_oneof.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import betterproto
from tests.output_betterproto.oneof import Test
from tests.util import get_test_case_json_data


def test_which_count():
message = Test()
message.from_json(get_test_case_json_data("oneof")[0].json)
assert betterproto.which_one_of(message, "foo") == ("pitied", 100)
assert message.which_one_of("foo") == ("pitied", 100)


def test_which_name():
message = Test()
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json)
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
assert message.which_one_of("foo") == ("pitier", "Mr. T")
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import datetime

import betterproto
from tests.output_betterproto.oneof_default_value_serialization import (
Test,
Message,
Expand All @@ -10,9 +9,9 @@


def assert_round_trip_serialization_works(message: Test) -> None:
assert betterproto.which_one_of(message, "value_type") == betterproto.which_one_of(
Test().from_json(message.to_json()), "value_type"
)
assert message.which_one_of("value_type") == Test().from_json(
message.to_json()
).which_one_of("value_type")


def test_oneof_default_value_serialization_works_for_all_values():
Expand Down Expand Up @@ -49,8 +48,8 @@ def test_oneof_default_value_serialization_works_for_all_values():
def test_oneof_no_default_values_passed():
message = Test()
assert (
betterproto.which_one_of(message, "value_type")
== betterproto.which_one_of(Test().from_json(message.to_json()), "value_type")
message.which_one_of("value_type")
== Test().from_json(message.to_json()).which_one_of("value_type")
== ("", None)
)

Expand All @@ -65,8 +64,8 @@ def test_oneof_nested_oneof_messages_are_serialized_with_defaults():
)
)
assert (
betterproto.which_one_of(message, "value_type")
== betterproto.which_one_of(Test().from_json(message.to_json()), "value_type")
message.which_one_of("value_type")
== Test().from_json(message.to_json()).which_one_of("value_type")
== (
"wrapped_nested_message_value",
NestedMessage(id=0, wrapped_message_value=Message(value=0)),
Expand Down
7 changes: 3 additions & 4 deletions tests/inputs/oneof_enum/test_oneof_enum.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest

import betterproto
from tests.output_betterproto.oneof_enum import (
Move,
Signal,
Expand All @@ -22,7 +21,7 @@ def test_which_one_of_returns_enum_with_default_value():
x=0, y=0
) # Proto3 will default this as there is no null
assert message.signal == Signal.PASS
assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS)
assert message.which_one_of("action") == ("signal", Signal.PASS)


def test_which_one_of_returns_enum_with_non_default_value():
Expand All @@ -37,12 +36,12 @@ def test_which_one_of_returns_enum_with_non_default_value():
x=0, y=0
) # Proto3 will default this as there is no null
assert message.signal == Signal.RESIGN
assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN)
assert message.which_one_of("action") == ("signal", Signal.RESIGN)


def test_which_one_of_returns_second_field_when_set():
message = Test()
message.from_json(get_test_case_json_data("oneof_enum")[0].json)
assert message.move == Move(x=2, y=3)
assert message.signal == Signal.PASS
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))
assert message.which_one_of("action") == ("move", Move(x=2, y=3))
9 changes: 9 additions & 0 deletions tests/inputs/rename/rename.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
syntax = "proto3";

// The fields that have overlapping names with betterproto.Message will be renamed.
message Test {
bool parse = 1;
bool serialized_on_wire = 2;
bool from_json = 3;
int32 this = 4;
}
12 changes: 12 additions & 0 deletions tests/inputs/rename/test_rename.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dataclasses import fields

from tests.output_betterproto.rename import Test


def test_renamed_fields():
assert {field.name for field in fields(Test)} == {
"parse_",
"serialized_on_wire_",
"from_json_",
"this",
}
Loading