Skip to content

Commit

Permalink
Track optional TypdeDict keys (#687)
Browse files Browse the repository at this point in the history
Backport of python/cpython#17214 (BPO-38834)
  • Loading branch information
Zac-HD authored and ilevkivskyi committed Nov 24, 2019
1 parent 024c81c commit 7ce3093
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
21 changes: 18 additions & 3 deletions typing_extensions/src_py3/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,9 @@ class Point2D(TypedDict):
x: int
y: int
class Point2Dor3D(Point2D, total=False):
z: int
class LabelPoint2D(Point2D, Label): ...
class Options(TypedDict, total=False):
Expand All @@ -442,7 +445,7 @@ class Options(TypedDict, total=False):
ann_module = ann_module2 = ann_module3 = None
A = B = CSub = G = CoolEmployee = CoolEmployeeWithDefault = object
XMeth = XRepr = HasCallProtocol = NoneAndForward = Loop = object
Point2D = LabelPoint2D = Options = object
Point2D = Point2Dor3D = LabelPoint2D = Options = object

gth = get_type_hints

Expand Down Expand Up @@ -1481,7 +1484,7 @@ def test_typeddict_create_errors(self):

def test_typeddict_errors(self):
Emp = TypedDict('Emp', {'name': str, 'id': int})
if hasattr(typing, 'TypedDict'):
if sys.version_info[:2] >= (3, 9):
self.assertEqual(TypedDict.__module__, 'typing')
else:
self.assertEqual(TypedDict.__module__, 'typing_extensions')
Expand Down Expand Up @@ -1543,6 +1546,11 @@ def test_total(self):
self.assertEqual(Options(log_level=2), {'log_level': 2})
self.assertEqual(Options.__total__, False)

@skipUnless(PY36, 'Python 3.6 required')
def test_optional_keys(self):
assert Point2Dor3D.__required_keys__ == frozenset(['x', 'y'])
assert Point2Dor3D.__optional_keys__ == frozenset(['z'])


@skipUnless(TYPING_3_5_3, "Python >= 3.5.3 required")
class AnnotatedTests(BaseTestCase):
Expand Down Expand Up @@ -1817,7 +1825,14 @@ def test_typing_extensions_includes_standard(self):
self.assertIn('runtime', a)

def test_typing_extensions_defers_when_possible(self):
exclude = {'overload', 'Text', 'TYPE_CHECKING', 'Final', 'get_type_hints'}
exclude = {
'overload',
'Text',
'TypedDict',
'TYPE_CHECKING',
'Final',
'get_type_hints'
}
for item in typing_extensions.__all__:
if item not in exclude and hasattr(typing, item):
self.assertIs(
Expand Down
22 changes: 18 additions & 4 deletions typing_extensions/src_py3/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,7 +1569,9 @@ def runtime_checkable(cls):
runtime = runtime_checkable


if hasattr(typing, 'TypedDict'):
if sys.version_info[:2] >= (3, 9):
# The standard library TypedDict in Python 3.8 does not store runtime information
# about which (if any) keys are optional. See https://bugs.python.org/issue38834
TypedDict = typing.TypedDict
else:
def _check_fails(cls, other):
Expand Down Expand Up @@ -1652,9 +1654,20 @@ def __new__(cls, name, bases, ns, total=True):
anns = ns.get('__annotations__', {})
msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
anns = {n: typing._type_check(tp, msg) for n, tp in anns.items()}
required = set(anns if total else ())
optional = set(() if total else anns)

for base in bases:
anns.update(base.__dict__.get('__annotations__', {}))
base_anns = base.__dict__.get('__annotations__', {})
anns.update(base_anns)
if getattr(base, '__total__', True):
required.update(base_anns)
else:
optional.update(base_anns)

tp_dict.__annotations__ = anns
tp_dict.__required_keys__ = frozenset(required)
tp_dict.__optional_keys__ = frozenset(optional)
if not hasattr(tp_dict, '__total__'):
tp_dict.__total__ = total
return tp_dict
Expand Down Expand Up @@ -1682,8 +1695,9 @@ class Point2D(TypedDict):
assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first')
The type info could be accessed via Point2D.__annotations__. TypedDict
supports two additional equivalent forms::
The type info can be accessed via the Point2D.__annotations__ dict, and
the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets.
TypedDict supports two additional equivalent forms::
Point2D = TypedDict('Point2D', x=int, y=int, label=str)
Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str})
Expand Down

0 comments on commit 7ce3093

Please sign in to comment.