Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,17 @@ def apply_generic_arguments(
assert isinstance(typ, TypeVarLikeType)
remaining_tvars.append(typ)

instance_type = None
if callable.instance_type is not None:
instance_type = expand_type(callable.instance_type, id_to_type)
assert isinstance(instance_type, ProperType)

return callable.copy_modified(
ret_type=expand_type(callable.ret_type, id_to_type),
variables=remaining_tvars,
type_guard=type_guard,
type_is=type_is,
instance_type=instance_type,
)


Expand Down
16 changes: 15 additions & 1 deletion mypy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
from mypy_extensions import u8

# High-level cache layout format
CACHE_VERSION: Final = 8
CACHE_VERSION: Final = 9

# Type used internally to represent errors:
# (path, line, column, end_line, end_column, severity, message, code)
Expand Down Expand Up @@ -558,6 +558,20 @@ def write_json(data: WriteBuffer, value: dict[str, Any]) -> None:
write_json_value(data, value[key])


def write_flags(data: WriteBuffer, flags: list[bool]) -> None:
assert len(flags) <= 26, "This many flags not supported yet"
packed = 0
for i, flag in enumerate(flags):
if flag:
packed |= 1 << i
write_int(data, packed)


def read_flags(data: ReadBuffer, num_flags: int) -> list[bool]:
packed = read_int(data)
return [(packed & (1 << i)) != 0 for i in range(num_flags)]


def write_errors(data: WriteBuffer, errs: list[ErrorTuple]) -> None:
write_tag(data, LIST_GEN)
write_int_bare(data, len(errs))
Expand Down
2 changes: 1 addition & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5515,7 +5515,7 @@ def check_except_handler_test(self, n: Expression, is_star: bool) -> Type:
if not item.is_type_obj():
self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
return self.default_exception_type(is_star)
exc_type = erase_typevars(item.ret_type)
exc_type = erase_typevars(item.get_instance_type())
elif isinstance(ttype, TypeType):
exc_type = ttype.item
else:
Expand Down
18 changes: 9 additions & 9 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@
from mypy.subtypes import (
covers_at_runtime,
find_member,
is_equivalent,
is_same_type,
is_subtype,
non_method_protocol_members,
Expand Down Expand Up @@ -689,7 +688,7 @@ def method_fullname(self, object_type: Type, method_name: str) -> str | None:
# For class method calls, object_type is a callable representing the class object.
# We "unwrap" it to a regular type, as the class/instance method difference doesn't
# affect the fully qualified name.
object_type = get_proper_type(object_type.ret_type)
object_type = object_type.get_instance_type()
elif isinstance(object_type, TypeType):
object_type = object_type.item

Expand Down Expand Up @@ -717,9 +716,9 @@ def always_returns_none(self, node: Expression) -> bool:
if isinstance(typ, Instance):
info = typ.type
elif isinstance(typ, CallableType) and typ.is_type_obj():
ret_type = get_proper_type(typ.ret_type)
if isinstance(ret_type, Instance):
info = ret_type.type
instance_type = typ.get_instance_type(force_fallback=True)
if isinstance(instance_type, Instance):
info = instance_type.type
else:
return False
else:
Expand Down Expand Up @@ -1668,9 +1667,10 @@ def check_callable_call(
callee = callee.with_unpacked_kwargs().with_normalized_var_args()
if callable_name is None and callee.name:
callable_name = callee.name
ret_type = get_proper_type(callee.ret_type)
if callee.is_type_obj() and isinstance(ret_type, Instance):
callable_name = ret_type.type.fullname
if callee.is_type_obj():
instance_type = callee.get_instance_type(force_fallback=True)
if isinstance(instance_type, Instance):
callable_name = instance_type.type.fullname
if isinstance(callable_node, RefExpr) and callable_node.fullname in ENUM_BASES:
# An Enum() call that failed SemanticAnalyzerPass2.check_enum_call().
return callee.ret_type, callee
Expand Down Expand Up @@ -1867,7 +1867,7 @@ def check_callable_call(
if (
callee.is_type_obj()
and (len(arg_types) == 1)
and is_equivalent(callee.ret_type, self.named_type("builtins.type"))
and is_named_instance(callee.get_instance_type(), "builtins.type")
):
callee = callee.copy_modified(ret_type=TypeType.make_normalized(arg_types[0]))

Expand Down
21 changes: 9 additions & 12 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,15 +407,8 @@ def validate_super_call(node: FuncBase, mx: MemberContext) -> None:
def analyze_type_callable_member_access(name: str, typ: FunctionLike, mx: MemberContext) -> Type:
# Class attribute.
# TODO super?
ret_type = typ.items[0].ret_type
assert isinstance(ret_type, ProperType)
if isinstance(ret_type, TupleType):
ret_type = tuple_fallback(ret_type)
if isinstance(ret_type, TypedDictType):
ret_type = ret_type.fallback
if isinstance(ret_type, LiteralType):
ret_type = ret_type.fallback
if isinstance(ret_type, Instance):
instance_type = typ.items[0].get_instance_type(force_fallback=True)
if isinstance(instance_type, Instance):
if not mx.is_operator:
# When Python sees an operator (eg `3 == 4`), it automatically translates that
# into something like `int.__eq__(3, 4)` instead of `(3).__eq__(4)` as an
Expand All @@ -432,14 +425,18 @@ def analyze_type_callable_member_access(name: str, typ: FunctionLike, mx: Member
# See https://github.com/python/mypy/pull/1787 for more info.
# TODO: do not rely on same type variables being present in all constructor overloads.
result = analyze_class_attribute_access(
ret_type, name, mx, original_vars=typ.items[0].variables, mcs_fallback=typ.fallback
instance_type,
name,
mx,
original_vars=typ.items[0].variables,
mcs_fallback=typ.fallback,
)
if result:
return result
# Look up from the 'type' type.
return _analyze_member_access(name, typ.fallback, mx)
else:
assert False, f"Unexpected type {ret_type!r}"
assert False, f"Unexpected type {instance_type!r}"


def analyze_type_type_member_access(
Expand Down Expand Up @@ -721,7 +718,7 @@ def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type:
dunder_get_type = expand_type_by_instance(bound_method, typ)

if isinstance(instance_type, FunctionLike) and instance_type.is_type_obj():
owner_type = instance_type.items[0].ret_type
owner_type = instance_type.items[0].get_instance_type()
instance_type = NoneType()
elif isinstance(instance_type, TypeType):
owner_type = instance_type.item
Expand Down
59 changes: 47 additions & 12 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import mypy.typeops
from mypy.argmap import ArgTypeExpander
from mypy.erasetype import erase_typevars
from mypy.expandtype import expand_type_by_instance
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import (
ARG_OPT,
Expand Down Expand Up @@ -275,7 +276,11 @@ def infer_constraints_for_callable(


def infer_constraints(
template: Type, actual: Type, direction: int, skip_neg_op: bool = False
template: Type,
actual: Type,
direction: int,
skip_neg_op: bool = False,
erase_types: bool = True,
) -> list[Constraint]:
"""Infer type constraints.

Expand Down Expand Up @@ -312,14 +317,14 @@ def infer_constraints(
# Return early on an empty branch.
return []
type_state.inferring.append((template, actual))
res = _infer_constraints(template, actual, direction, skip_neg_op)
res = _infer_constraints(template, actual, direction, skip_neg_op, erase_types)
type_state.inferring.pop()
return res
return _infer_constraints(template, actual, direction, skip_neg_op)
return _infer_constraints(template, actual, direction, skip_neg_op, erase_types)


def _infer_constraints(
template: Type, actual: Type, direction: int, skip_neg_op: bool
template: Type, actual: Type, direction: int, skip_neg_op: bool, erase_types: bool
) -> list[Constraint]:
orig_template = template
template = get_proper_type(template)
Expand Down Expand Up @@ -424,7 +429,7 @@ def _infer_constraints(
return []

# Remaining cases are handled by ConstraintBuilderVisitor.
return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op))
return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op, erase_types))


def _is_type_type(tp: ProperType) -> TypeGuard[TypeType | UnionType]:
Expand Down Expand Up @@ -659,14 +664,20 @@ class ConstraintBuilderVisitor(TypeVisitor[list[Constraint]]):
# TODO: The value may be None. Is that actually correct?
actual: ProperType

def __init__(self, actual: ProperType, direction: int, skip_neg_op: bool) -> None:
def __init__(
self, actual: ProperType, direction: int, skip_neg_op: bool, erase_types: bool
) -> None:
# Direction must be SUBTYPE_OF or SUPERTYPE_OF.
self.actual = actual
self.direction = direction
# Whether to skip polymorphic inference (involves inference in opposite direction)
# this is used to prevent infinite recursion when both template and actual are
# generic callables.
self.skip_neg_op = skip_neg_op
# Normally we should erase generic actual type when inferring against type[T]
# to avoid leaking type variables, see testGenericClassAsArgumentToType.
# The only exception is self-types in generic classes, where we set this to False.
self.erase_types = erase_types

# Trivial leaf types

Expand Down Expand Up @@ -759,13 +770,11 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
and template.type.is_protocol
and self.direction == SUPERTYPE_OF
):
ret_type = get_proper_type(actual.ret_type)
if isinstance(ret_type, TupleType):
ret_type = mypy.typeops.tuple_fallback(ret_type)
if isinstance(ret_type, Instance):
instance_type = actual.get_instance_type(force_fallback=True)
if isinstance(instance_type, Instance):
res.extend(
self.infer_constraints_from_protocol_members(
ret_type, template, ret_type, template, class_obj=True
instance_type, template, instance_type, template, class_obj=True
)
)
actual = actual.fallback
Expand Down Expand Up @@ -1213,6 +1222,20 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
elif isinstance(self.actual, Overloaded):
return self.infer_against_overloaded(self.actual, template)
elif isinstance(self.actual, TypeType):
# This matches the corresponding logic in subtypes.py.
item = self.actual.item
if isinstance(item, TupleType):
item = mypy.typeops.tuple_fallback(item)
if isinstance(item, Instance):
constructor = mypy.typeops.type_object_type(item.type)
constructor = expand_type_by_instance(constructor, item)
# Only consider return type to match historic behavior (see below).
if isinstance(constructor, CallableType):
return infer_constraints(
template.ret_type, constructor.ret_type, self.direction
)
elif isinstance(constructor, Overloaded):
return self.infer_against_overloaded(constructor, template, ret_only=True)
return infer_constraints(template.ret_type, self.actual.item, self.direction)
elif isinstance(self.actual, Instance):
# Instances with __call__ method defined are considered structural
Expand All @@ -1228,14 +1251,16 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
return []

def infer_against_overloaded(
self, overloaded: Overloaded, template: CallableType
self, overloaded: Overloaded, template: CallableType, ret_only: bool = False
) -> list[Constraint]:
# Create constraints by matching an overloaded type against a template.
# This is tricky to do in general. We cheat by only matching against
# the first overload item that is callable compatible. This
# seems to work somewhat well, but we should really use a more
# reliable technique.
item = find_matching_overload_item(overloaded, template)
if ret_only:
return infer_constraints(template.ret_type, item.ret_type, self.direction)
return infer_constraints(template, item, self.direction)

def visit_tuple_type(self, template: TupleType) -> list[Constraint]:
Expand Down Expand Up @@ -1398,8 +1423,18 @@ def visit_overloaded(self, template: Overloaded) -> list[Constraint]:

def visit_type_type(self, template: TypeType) -> list[Constraint]:
if isinstance(self.actual, CallableType):
if self.actual.is_type_obj():
instance_type = self.actual.get_instance_type()
if self.erase_types:
instance_type = erase_typevars(instance_type)
return infer_constraints(template.item, instance_type, self.direction)
return infer_constraints(template.item, self.actual.ret_type, self.direction)
elif isinstance(self.actual, Overloaded):
if self.actual.is_type_obj():
instance_type = self.actual.items[0].get_instance_type()
if self.erase_types:
instance_type = erase_typevars(instance_type)
return infer_constraints(template.item, instance_type, self.direction)
return infer_constraints(template.item, self.actual.items[0].ret_type, self.direction)
elif isinstance(self.actual, TypeType):
return infer_constraints(template.item, self.actual.item, self.direction)
Expand Down
9 changes: 7 additions & 2 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,11 +485,16 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
else:
arg_types = self.expand_types(t.arg_types)
instance_type = None
if t.instance_type is not None:
instance_type = t.instance_type.accept(self)
assert isinstance(instance_type, ProperType)
expanded = t.copy_modified(
arg_types=arg_types,
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
type_guard=t.type_guard.accept(self) if t.type_guard is not None else None,
type_is=t.type_is.accept(self) if t.type_is is not None else None,
instance_type=instance_type,
)
if needs_normalization:
return expanded.with_normalized_var_args()
Expand Down
2 changes: 2 additions & 0 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ def visit_callable_type(self, ct: CallableType) -> None:
ct.type_guard.accept(self)
if ct.type_is is not None:
ct.type_is.accept(self)
if ct.instance_type is not None:
ct.instance_type.accept(self)

def visit_overloaded(self, t: Overloaded) -> None:
for ct in t.items:
Expand Down
6 changes: 6 additions & 0 deletions mypy/indirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ def visit_instance(self, t: types.Instance) -> None:
def visit_callable_type(self, t: types.CallableType) -> None:
self._visit_type_list(t.arg_types)
self._visit(t.ret_type)
if t.type_guard is not None:
self._visit(t.type_guard)
if t.type_is is not None:
self._visit(t.type_is)
if t.instance_type is not None:
self._visit(t.instance_type)
self._visit_type_tuple(t.variables)

def visit_overloaded(self, t: types.Overloaded) -> None:
Expand Down
5 changes: 4 additions & 1 deletion mypy/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,11 @@ def infer_type_arguments(
actual: Type,
is_supertype: bool = False,
skip_unsatisfied: bool = False,
erase_types: bool = True,
) -> list[Type | None]:
# Like infer_function_type_arguments, but only match a single type
# against a generic type.
constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF)
constraints = infer_constraints(
template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF, erase_types=erase_types
)
return solve_constraints(type_vars, constraints, skip_unsatisfied=skip_unsatisfied)[0]
8 changes: 8 additions & 0 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,10 +773,14 @@ def join_similar_callables(t: CallableType, s: CallableType) -> CallableType:
fallback = t.fallback
else:
fallback = s.fallback
instance_type = None
if t.instance_type is not None and s.instance_type is not None:
instance_type = join_types(t.instance_type, s.instance_type)
return t.copy_modified(
arg_types=arg_types,
arg_names=combine_arg_names(t, s),
ret_type=join_types(t.ret_type, s.ret_type),
instance_type=instance_type,
fallback=fallback,
name=None,
)
Expand Down Expand Up @@ -827,10 +831,14 @@ def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType:
fallback = t.fallback
else:
fallback = s.fallback
instance_type = None
if t.instance_type is not None and s.instance_type is not None:
instance_type = join_types(t.instance_type, s.instance_type)
return t.copy_modified(
arg_types=arg_types,
arg_names=combine_arg_names(t, s),
ret_type=join_types(t.ret_type, s.ret_type),
instance_type=instance_type,
fallback=fallback,
name=None,
)
Expand Down
Loading
Loading