Skip to content

Commit

Permalink
Changed the type of an async def function from typing.Coroutine t…
Browse files Browse the repository at this point in the history
…o `types.CoroutineType` for improved type accuracy. (#9850)
  • Loading branch information
erictraut authored Feb 7, 2025
1 parent 08918ec commit c667546
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 13 deletions.
9 changes: 6 additions & 3 deletions packages/pyright-internal/src/analyzer/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,10 @@ export class Checker extends ParseTreeWalker {
node
);

if (isClassInstance(returnType) && ClassType.isBuiltIn(returnType, 'Coroutine')) {
if (
isClassInstance(returnType) &&
ClassType.isBuiltIn(returnType, ['Coroutine', 'CoroutineType'])
) {
this._evaluator.addDiagnostic(
DiagnosticRule.reportUnusedCoroutine,
LocMessage.unusedCoroutine(),
Expand Down Expand Up @@ -1064,7 +1067,7 @@ export class Checker extends ParseTreeWalker {
let yieldType: Type | undefined;
let sendType: Type | undefined;

if (isClassInstance(yieldFromType) && ClassType.isBuiltIn(yieldFromType, 'Coroutine')) {
if (isClassInstance(yieldFromType) && ClassType.isBuiltIn(yieldFromType, ['Coroutine', 'CoroutineType'])) {
// Handle the case of old-style (pre-await) coroutines.
yieldType = UnknownType.create();
} else {
Expand Down Expand Up @@ -1842,7 +1845,7 @@ export class Checker extends ParseTreeWalker {
isExprFunction = false;
}

if (!isClassInstance(subtype) || !ClassType.isBuiltIn(subtype, 'Coroutine')) {
if (!isClassInstance(subtype) || !ClassType.isBuiltIn(subtype, ['Coroutine', 'CoroutineType'])) {
isCoroutine = false;
}
});
Expand Down
4 changes: 2 additions & 2 deletions packages/pyright-internal/src/analyzer/codeFlowEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1846,7 +1846,7 @@ export function getCodeFlowEngine(
if (returnType) {
if (
isClassInstance(returnType) &&
ClassType.isBuiltIn(returnType, 'Coroutine') &&
ClassType.isBuiltIn(returnType, ['Coroutine', 'CoroutineType']) &&
returnType.priv.typeArgs &&
returnType.priv.typeArgs.length >= 3
) {
Expand Down Expand Up @@ -1955,7 +1955,7 @@ export function getCodeFlowEngine(
if (isAsync) {
if (
isClassInstance(returnType) &&
ClassType.isBuiltIn(returnType, 'Coroutine') &&
ClassType.isBuiltIn(returnType, ['Coroutine', 'CoroutineType']) &&
returnType.priv.typeArgs &&
returnType.priv.typeArgs.length >= 3
) {
Expand Down
11 changes: 7 additions & 4 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14668,7 +14668,10 @@ export function createTypeEvaluator(
}

// Handle old-style (pre-await) Coroutines as a special case.
if (isClassInstance(yieldFromSubtype) && ClassType.isBuiltIn(yieldFromSubtype, 'Coroutine')) {
if (
isClassInstance(yieldFromSubtype) &&
ClassType.isBuiltIn(yieldFromSubtype, ['Coroutine', 'CoroutineType'])
) {
return UnknownType.create();
}

Expand Down Expand Up @@ -19313,8 +19316,8 @@ export function createTypeEvaluator(
}

if (!awaitableReturnType || !isGenerator) {
// Wrap in either an Awaitable or a Coroutine, which is a subclass of Awaitable.
const awaitableType = getTypingType(node, useCoroutine ? 'Coroutine' : 'Awaitable');
// Wrap in either an Awaitable or a CoroutineType, which is a subclass of Awaitable.
const awaitableType = useCoroutine ? getTypesType(node, 'CoroutineType') : getTypingType(node, 'Awaitable');
if (awaitableType && isInstantiableClass(awaitableType)) {
awaitableReturnType = ClassType.cloneAsInstance(
ClassType.specialize(
Expand Down Expand Up @@ -19462,7 +19465,7 @@ export function createTypeEvaluator(
const iteratorTypeResult = getTypeOfExpression(yieldNode.d.expr);
if (
isClassInstance(iteratorTypeResult.type) &&
ClassType.isBuiltIn(iteratorTypeResult.type, 'Coroutine')
ClassType.isBuiltIn(iteratorTypeResult.type, ['Coroutine', 'CoroutineType'])
) {
const yieldType =
iteratorTypeResult.type.priv.typeArgs &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
//// y = [|/*marker2*/test|]
helper.verifyHover('markdown', {
marker1: '```python\n(function) async def test() -> None\n```',
marker2: '```python\n(function) def test() -> Coroutine[Any, Any, None]\n```',
marker2: '```python\n(function) def test() -> CoroutineType[Any, Any, None]\n```',
});
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/samples/callSite2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
reveal_type(v3, expected_text="Unknown")

v4 = async_call(1)
reveal_type(v4, expected_text="Coroutine[Any, Any, Unknown]")
reveal_type(v4, expected_text="CoroutineType[Any, Any, Unknown]")
4 changes: 2 additions & 2 deletions packages/pyright-internal/src/tests/samples/generator13.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_generator4() -> AsyncGenerator[int, None]:

async def demo_bug1() -> None:
v1 = get_generator1()
reveal_type(v1, expected_text="Coroutine[Any, Any, AsyncGenerator[str, None]]")
reveal_type(v1, expected_text="CoroutineType[Any, Any, AsyncGenerator[str, None]]")
gen = await v1
reveal_type(gen, expected_text="AsyncGenerator[str, None]")
async for s in gen:
Expand All @@ -53,7 +53,7 @@ async def demo_bug1() -> None:

async def demo_bug2() -> None:
v1 = get_generator2()
reveal_type(v1, expected_text="Coroutine[Any, Any, AsyncIterator[str]]")
reveal_type(v1, expected_text="CoroutineType[Any, Any, AsyncIterator[str]]")
gen = await v1
reveal_type(gen, expected_text="AsyncIterator[str]")
async for s in gen:
Expand Down

0 comments on commit c667546

Please sign in to comment.