Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

don't add redundant types to unions to prevent unnecessary unknown/any errors #447

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion packages/pyright-internal/src/analyzer/codeFlowEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -898,7 +898,8 @@ export function getCodeFlowEngine(
}
}

const effectiveType = typesToCombine.length > 0 ? combineTypes(typesToCombine) : undefined;
const effectiveType =
typesToCombine.length > 0 ? combineTypes(typesToCombine, undefined, evaluator) : undefined;

return setCacheEntry(branchNode, effectiveType, sawIncomplete);
}
Expand Down
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/analyzer/operations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ export function getTypeOfBinaryOperation(
flags | EvaluatorFlags.ExpectingInstantiableType
);

let newUnion = combineTypes([adjustedLeftType, adjustedRightType]);
let newUnion = combineTypes([adjustedLeftType, adjustedRightType], undefined, evaluator);

const unionClass = evaluator.getUnionClassType();
if (unionClass && isInstantiableClass(unionClass)) {
Expand Down
112 changes: 98 additions & 14 deletions packages/pyright-internal/src/analyzer/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import { Uri } from '../common/uri/uri';
import { ArgumentNode, ExpressionNode, NameNode, ParameterCategory } from '../parser/parseNodes';
import { ClassDeclaration, FunctionDeclaration, SpecialBuiltInClassDeclaration } from './declaration';
import { Symbol, SymbolTable } from './symbol';
import { TypeEvaluator } from './typeEvaluatorTypes';
import { AssignTypeFlags } from './typeUtils';

export const enum TypeCategory {
// Name is not bound to a value of any type.
Expand Down Expand Up @@ -3237,7 +3239,6 @@ export function removeUnbound(type: Type): Type {

return type;
}

export function removeFromUnion(type: Type, removeFilter: (type: Type) => boolean) {
if (isUnion(type)) {
const remainingTypes = type.subtypes.filter((t) => !removeFilter(t));
Expand Down Expand Up @@ -3265,11 +3266,28 @@ export function findSubtype(type: Type, filter: (type: UnionableType | NeverType
return filter(type) ? type : undefined;
}

// Combines multiple types into a single type. If the types are
// the same, only one is returned. If they differ, they
// are combined into a UnionType. NeverTypes are filtered out.
// If no types remain in the end, a NeverType is returned.
export function combineTypes(subtypes: Type[], maxSubtypeCount?: number): Type {
/**
* Combines multiple types into a single type. If the types are
* the same, only one is returned. If they differ, they
* are combined into a UnionType. NeverTypes are filtered out.
* If no types remain in the end, a NeverType is returned.
*
* if a {@link TypeEvaluator} is provided, it not only checks that
* the types aren't the same, but also prevents redundant subtypes from
* being added to the union. eg. adding `Literal[1]` to a union of `int | str`
* is useless, so the union is left as-is. when adding a supertype to a union
* that contains a subtype of it, that subtype becomes redundant and therefore
* gets removed (eg. adding `int` to `Literal[1] | str` will result in
* `int | str`). this is useful to prevent cases where a narrowed type would be
* treated as partially unknown unnecessarily (eg. `object | list[Any]`).
*
* a {@link TypeEvaluator} should not be provided in cases where the union
* intentionally contains redundant information for the purpose of autocomplete.
* i don't think there are any situations where this is supported currently, but
* it's something to keep in mind if we end up implementing
* https://github.com/DetachHead/basedpyright/issues/320
*/
export function combineTypes(subtypes: Type[], maxSubtypeCount?: number, evaluator?: TypeEvaluator): Type {
// Filter out any "Never" and "NoReturn" types.
let sawNoReturn = false;

Expand Down Expand Up @@ -3352,21 +3370,87 @@ export function combineTypes(subtypes: Type[], maxSubtypeCount?: number): Type {
return UnknownType.create();
}

const newUnionType = UnionType.create();
let newUnionType = UnionType.create();
if (typeAliasSources.size > 0) {
newUnionType.typeAliasSources = typeAliasSources;
}

let hitMaxSubtypeCount = false;

expandedTypes.forEach((subtype, index) => {
if (index === 0) {
UnionType.addType(newUnionType, subtype as UnionableType);
} else {
if (maxSubtypeCount === undefined || newUnionType.subtypes.length < maxSubtypeCount) {
_addTypeIfUnique(newUnionType, subtype as UnionableType);
expandedTypes.forEach((subtype) => {
let shouldAddType = false;
if (
// if an evaluator isn't specified, don't do the redundant type check
!evaluator ||
// if it's a type var (including recursive type aliases which get synthesized into typevars),
// we don't know the bound type at this point so it's not safe to do the redundant type check
isTypeVar(subtype) ||
// no types have been added to the union yet which causes its type to be Never, which would break
// the redundant type check
!newUnionType.subtypes.length
) {
shouldAddType = true;
} else if (
// i cant figure out how to check whether a special form is assignable, for now we just skip the
// redundant check on special forms
subtype.specialForm ||
newUnionType.subtypes.find((subtype) => subtype.specialForm)
) {
shouldAddType = true;
} else if (
// if the new type is a subtype of a type that's already in the union, it's redundant and therefore
// does not need to be added to the union
!evaluator.assignType(
newUnionType,
subtype,
undefined,
undefined,
undefined,
AssignTypeFlags.OverloadOverlapCheck
)
) {
shouldAddType = true;
if (
// if the new type is a supertype of a type that's already in the union, we need to get rid of that
// type then replace it with the new wider one
evaluator.assignType(
subtype,
newUnionType,
undefined,
undefined,
undefined,
AssignTypeFlags.OverloadOverlapCheck
)
) {
const filteredType = removeFromUnion(newUnionType, (type) =>
evaluator.assignType(
subtype,
type,
undefined,
undefined,
undefined,
AssignTypeFlags.OverloadOverlapCheck
)
);
if (isUnion(filteredType)) {
newUnionType = filteredType;
} else {
newUnionType = UnionType.create();
if (filteredType.category !== TypeCategory.Never) {
UnionType.addType(newUnionType, filteredType as UnionableType);
}
}
}
}
if (shouldAddType) {
if (!newUnionType.subtypes.length) {
UnionType.addType(newUnionType, subtype as UnionableType);
} else {
hitMaxSubtypeCount = true;
if (maxSubtypeCount === undefined || newUnionType.subtypes.length < maxSubtypeCount) {
_addTypeIfUnique(newUnionType, subtype as UnionableType);
} else {
hitMaxSubtypeCount = true;
}
}
}
});
Expand Down
16 changes: 16 additions & 0 deletions packages/pyright-internal/src/tests/samples/typeNarrowingBased.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Any, assert_type


def foo(value: object):
print(value)
if isinstance(value, list):
_ = assert_type(value, list[Any])
_ = assert_type(value, object)

def bar(value: object):
print(value)
if isinstance(value, list):
_ = assert_type(value, list[Any])
else:
_ = assert_type(value, object)
_ = assert_type(value, object)
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __bool__(self) -> Literal[False]:

def func7(x: NoneProto | None):
if x is None:
reveal_type(x, expected_text="None")
reveal_type(x, expected_text="Never") # should be None. see https://github.com/DetachHead/basedpyright/issues/459
else:
reveal_type(x, expected_text="NoneProto")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ def func1(
if isinstance(obj, Callable):
reveal_type(obj, expected_text="((int, str) -> int) | B | TCall1@func1")
else:
reveal_type(obj, expected_text="list[int] | C | D | A")
reveal_type(obj, expected_text="list[int] | C | A")
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,12 @@ test('subscript context manager types on 3.8', () => {
],
});
});

test("useless type isn't added to union after if statement", () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.reportAssertTypeFailure = 'error';
const analysisResults = typeAnalyzeSampleFiles(['typeNarrowingBased.py'], configOptions);
validateResultsButBased(analysisResults, {
errors: [],
});
});
Loading