diff --git a/typing_extensions/src_py3/test_typing_extensions.py b/typing_extensions/src_py3/test_typing_extensions.py index e59d0da9..e39612c6 100644 --- a/typing_extensions/src_py3/test_typing_extensions.py +++ b/typing_extensions/src_py3/test_typing_extensions.py @@ -428,6 +428,9 @@ class Point2D(TypedDict): x: int y: int +class Point2Dor3D(Point2D, total=False): + z: int + class LabelPoint2D(Point2D, Label): ... class Options(TypedDict, total=False): @@ -442,7 +445,7 @@ class Options(TypedDict, total=False): ann_module = ann_module2 = ann_module3 = None A = B = CSub = G = CoolEmployee = CoolEmployeeWithDefault = object XMeth = XRepr = HasCallProtocol = NoneAndForward = Loop = object - Point2D = LabelPoint2D = Options = object + Point2D = Point2Dor3D = LabelPoint2D = Options = object gth = get_type_hints @@ -1481,7 +1484,7 @@ def test_typeddict_create_errors(self): def test_typeddict_errors(self): Emp = TypedDict('Emp', {'name': str, 'id': int}) - if hasattr(typing, 'TypedDict'): + if sys.version_info[:2] >= (3, 9): self.assertEqual(TypedDict.__module__, 'typing') else: self.assertEqual(TypedDict.__module__, 'typing_extensions') @@ -1543,6 +1546,11 @@ def test_total(self): self.assertEqual(Options(log_level=2), {'log_level': 2}) self.assertEqual(Options.__total__, False) + @skipUnless(PY36, 'Python 3.6 required') + def test_optional_keys(self): + assert Point2Dor3D.__required_keys__ == frozenset(['x', 'y']) + assert Point2Dor3D.__optional_keys__ == frozenset(['z']) + @skipUnless(TYPING_3_5_3, "Python >= 3.5.3 required") class AnnotatedTests(BaseTestCase): @@ -1817,7 +1825,14 @@ def test_typing_extensions_includes_standard(self): self.assertIn('runtime', a) def test_typing_extensions_defers_when_possible(self): - exclude = {'overload', 'Text', 'TYPE_CHECKING', 'Final', 'get_type_hints'} + exclude = { + 'overload', + 'Text', + 'TypedDict', + 'TYPE_CHECKING', + 'Final', + 'get_type_hints' + } for item in typing_extensions.__all__: if item not in exclude and hasattr(typing, item): self.assertIs( diff --git a/typing_extensions/src_py3/typing_extensions.py b/typing_extensions/src_py3/typing_extensions.py index b0e03f57..dceb35e4 100644 --- a/typing_extensions/src_py3/typing_extensions.py +++ b/typing_extensions/src_py3/typing_extensions.py @@ -1569,7 +1569,9 @@ def runtime_checkable(cls): runtime = runtime_checkable -if hasattr(typing, 'TypedDict'): +if sys.version_info[:2] >= (3, 9): + # The standard library TypedDict in Python 3.8 does not store runtime information + # about which (if any) keys are optional. See https://bugs.python.org/issue38834 TypedDict = typing.TypedDict else: def _check_fails(cls, other): @@ -1652,9 +1654,20 @@ def __new__(cls, name, bases, ns, total=True): anns = ns.get('__annotations__', {}) msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type" anns = {n: typing._type_check(tp, msg) for n, tp in anns.items()} + required = set(anns if total else ()) + optional = set(() if total else anns) + for base in bases: - anns.update(base.__dict__.get('__annotations__', {})) + base_anns = base.__dict__.get('__annotations__', {}) + anns.update(base_anns) + if getattr(base, '__total__', True): + required.update(base_anns) + else: + optional.update(base_anns) + tp_dict.__annotations__ = anns + tp_dict.__required_keys__ = frozenset(required) + tp_dict.__optional_keys__ = frozenset(optional) if not hasattr(tp_dict, '__total__'): tp_dict.__total__ = total return tp_dict @@ -1682,8 +1695,9 @@ class Point2D(TypedDict): assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first') - The type info could be accessed via Point2D.__annotations__. TypedDict - supports two additional equivalent forms:: + The type info can be accessed via the Point2D.__annotations__ dict, and + the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets. + TypedDict supports two additional equivalent forms:: Point2D = TypedDict('Point2D', x=int, y=int, label=str) Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str})