Skip to content

Commit

Permalink
Un-generify call_infer and use a CallArg enum instead
Browse files Browse the repository at this point in the history
Summary: This lets us reduce call_method_generic, call_method, and call_method_with_types down to a single call_method function. It'll also allow me to prepend a "self" argument to an arguments list without worrying about whether the other arguments are expressions or types.

Reviewed By: ndmitchell

Differential Revision: D67188826

fbshipit-source-id: ba9dd6970bcd8c03975687ad7decfb3967694ec5
  • Loading branch information
rchen152 authored and facebook-github-bot committed Dec 13, 2024
1 parent d9853af commit 8a77318
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 98 deletions.
55 changes: 27 additions & 28 deletions pyre2/pyre2/bin/alt/answers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use starlark_map::ordered_set::OrderedSet;
use starlark_map::small_map::Entry;
use starlark_map::small_map::SmallMap;

use crate::alt::expr::TypeCallArg;
use crate::alt::expr::CallArg;
use crate::ast::Ast;
use crate::binding::binding::Binding;
use crate::binding::binding::BindingAnnotation;
Expand Down Expand Up @@ -634,12 +634,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
match iterable {
Type::ClassType(cls) => {
let ty = if self.has_attribute(cls.class_object(), &dunder::ITER) {
let iterator_ty =
self.call_method_with_types(iterable, &dunder::ITER, range, &[]);
self.call_method_with_types(&iterator_ty, &dunder::NEXT, range, &[])
let iterator_ty = self.call_method(iterable, &dunder::ITER, range, &[], &[]);
self.call_method(&iterator_ty, &dunder::NEXT, range, &[], &[])
} else if self.has_attribute(cls.class_object(), &dunder::GETITEM) {
let arg = TypeCallArg::new(self.stdlib.int().to_type(), range);
self.call_method_with_types(iterable, &dunder::GETITEM, range, &[arg])
let int_ty = self.stdlib.int().to_type();
let arg = CallArg::Type(&int_ty, range);
self.call_method(iterable, &dunder::GETITEM, range, &[arg], &[])
} else {
self.error(range, format!("Class `{}` is not iterable", cls.name()))
};
Expand Down Expand Up @@ -774,13 +774,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
) -> Type {
match kind {
ContextManagerKind::Sync => {
self.call_method_with_types(context_manager_type, &dunder::ENTER, range, &[])
self.call_method(context_manager_type, &dunder::ENTER, range, &[], &[])
}
ContextManagerKind::Async => match self.unwrap_awaitable(&self.call_method_with_types(
ContextManagerKind::Async => match self.unwrap_awaitable(&self.call_method(
context_manager_type,
&dunder::AENTER,
range,
&[],
&[],
)) {
Some(ty) => ty,
None => self.error(range, format!("Expected `{}` to be async", dunder::AENTER)),
Expand All @@ -795,32 +796,28 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
range: TextRange,
) -> Type {
let base_exception_class_type = Type::type_form(self.stdlib.base_exception().to_type());
let arg1 = Type::Union(vec![base_exception_class_type, Type::None]);
let arg2 = Type::Union(vec![self.stdlib.base_exception().to_type(), Type::None]);
let arg3 = Type::Union(vec![self.stdlib.traceback_type().to_type(), Type::None]);
let exit_arg_types = [
TypeCallArg::new(
Type::Union(vec![base_exception_class_type, Type::None]),
range,
),
TypeCallArg::new(
Type::Union(vec![self.stdlib.base_exception().to_type(), Type::None]),
range,
),
TypeCallArg::new(
Type::Union(vec![self.stdlib.traceback_type().to_type(), Type::None]),
range,
),
CallArg::Type(&arg1, range),
CallArg::Type(&arg2, range),
CallArg::Type(&arg3, range),
];
match kind {
ContextManagerKind::Sync => self.call_method_with_types(
ContextManagerKind::Sync => self.call_method(
context_manager_type,
&dunder::EXIT,
range,
&exit_arg_types,
&[],
),
ContextManagerKind::Async => match self.unwrap_awaitable(&self.call_method_with_types(
ContextManagerKind::Async => match self.unwrap_awaitable(&self.call_method(
context_manager_type,
&dunder::AEXIT,
range,
&exit_arg_types,
&[],
)) {
Some(ty) => ty,
None => self.error(range, format!("Expected `{}` to be async", dunder::AEXIT)),
Expand Down Expand Up @@ -923,7 +920,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
&base,
&inplace_dunder(x.op),
x.range,
&[*x.value.clone()],
&[CallArg::Expr(&x.value)],
&[],
)
}
Expand All @@ -950,15 +947,16 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let base = self.expr(&x.value, None);
let slice_ty = self.expr(&x.slice, None);
let value_ty = self.solve_binding_inner(b);
self.call_method_with_types(
self.call_method(
&base,
&dunder::SETITEM,
x.range,
&[
TypeCallArg::new(slice_ty, x.slice.range()),
CallArg::Type(&slice_ty, x.slice.range()),
// use the subscript's location
TypeCallArg::new(value_ty, x.range),
CallArg::Type(&value_ty, x.range),
],
&[],
)
}
Binding::UnpackedValue(b, range, pos) => {
Expand Down Expand Up @@ -1237,12 +1235,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
// TODO: check against duplicate keys (optional)
let key_ty = self.expr(mapping_key, None);
let binding_ty = self.get_idx(*binding_key).arc_clone();
let arg = TypeCallArg::new(key_ty, mapping_key.range());
self.call_method_with_types(
let arg = CallArg::Type(&key_ty, mapping_key.range());
self.call_method(
&binding_ty,
&dunder::GETITEM,
mapping_key.range(),
&[arg],
&[],
)
}
Binding::PatternMatchClassPositional(_, idx, key, range) => {
Expand Down
110 changes: 40 additions & 70 deletions pyre2/pyre2/bin/alt/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

use std::slice;

use dupe::Dupe;
use ruff_python_ast::name::Name;
use ruff_python_ast::Arguments;
Expand Down Expand Up @@ -55,22 +53,40 @@ enum CallStyle<'a> {
FreeForm,
}

/// This struct bundles a `Type` with a `TextRange`, allowing us to typecheck function calls
/// when we only know the types of the arguments but not the original expressions.
pub struct TypeCallArg {
ty: Type,
range: TextRange,
#[derive(Clone)]
pub enum CallArg<'a> {
/// Bundles a `Type` with a `TextRange`, allowing us to typecheck function calls
/// when we only know the types of the arguments but not the original expressions.
Type(&'a Type, TextRange),
Expr(&'a Expr),
}

impl Ranged for TypeCallArg {
impl Ranged for CallArg<'_> {
fn range(&self) -> TextRange {
self.range
match self {
Self::Type(_, r) => *r,
Self::Expr(e) => e.range(),
}
}
}

impl TypeCallArg {
pub fn new(ty: Type, range: TextRange) -> Self {
Self { ty, range }
impl CallArg<'_> {
/// Check an argument against the type hint (if any) on the corresponding parameter
fn check_against_hint<Ans: LookupAnswer>(
&self,
answers: &AnswersSolver<Ans>,
hint: Option<&Type>,
) {
match self {
Self::Type(ty, r) => {
if let Some(hint) = hint {
answers.check_type(hint, ty, *r);
}
}
Self::Expr(e) => {
answers.expr(e, hint);
}
}
}
}

Expand Down Expand Up @@ -192,15 +208,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
}

fn call_method_generic<T: Ranged>(
pub fn call_method(
&self,
ty: &Type,
method_name: &Name,
range: TextRange,
args: &[T],
args: &[CallArg],
keywords: &[Keyword],
// Function to check an argument against the type hint (if any) on the corresponding parameter
check_arg: &dyn Fn(&T, Option<&Type>),
) -> Type {
self.distribute_over_union(ty, |ty| {
let callable = match self
Expand All @@ -210,39 +224,10 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
Ok(ty) => self.as_call_target_or_error(ty, CallStyle::Method(method_name), range),
Err(msg) => self.error_call_target(range, msg),
};
self.call_infer(callable, args, keywords, check_arg, range)
self.call_infer(callable, args, keywords, range)
})
}

pub fn call_method(
&self,
ty: &Type,
method_name: &Name,
range: TextRange,
args: &[Expr],
keywords: &[Keyword],
) -> Type {
let check_arg = &|arg: &Expr, hint: Option<&Type>| {
self.expr(arg, hint);
};
self.call_method_generic(ty, method_name, range, args, keywords, check_arg)
}

pub fn call_method_with_types(
&self,
ty: &Type,
method_name: &Name,
range: TextRange,
args: &[TypeCallArg],
) -> Type {
let check_arg = &|arg: &TypeCallArg, hint: Option<&Type>| {
if let Some(hint) = hint {
self.check_type(hint, &arg.ty, arg.range);
}
};
self.call_method_generic(ty, method_name, range, args, &[], check_arg)
}

pub fn expr(&self, x: &Expr, check: Option<&Type>) -> Type {
match check {
Some(want) if !want.is_any() => {
Expand All @@ -263,12 +248,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
}

fn call_infer<T: Ranged>(
fn call_infer(
&self,
call_target: CallTarget,
args: &[T],
args: &[CallArg],
keywords: &[Keyword],
check_arg: &dyn Fn(&T, Option<&Type>),
range: TextRange,
) -> Type {
match call_target {
Expand All @@ -281,7 +265,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
),
args,
keywords,
check_arg,
range,
);
cls.self_type()
Expand Down Expand Up @@ -328,7 +311,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
);
num_positional = -1;
}
check_arg(arg, hint);
arg.check_against_hint(self, hint);
}
let mut need_positional = 0;
let mut kwparams = SmallMap::new();
Expand Down Expand Up @@ -408,7 +391,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
Args::Ellipsis => {
// Deal with Callable[..., R]
for t in args {
check_arg(t, None);
t.check_against_hint(self, None);
}
callable.ret
}
Expand All @@ -435,17 +418,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let method_type = self.attr_infer(lhs, &Name::new(op.dunder()), range);
let callable =
self.as_call_target_or_error(method_type, CallStyle::BinaryOp(op), range);
self.call_infer(
callable,
&[TypeCallArg::new(rhs, range)],
&[],
&|arg, hint| {
if let Some(hint) = hint {
self.check_type(hint, &arg.ty, arg.range);
}
},
range,
)
self.call_infer(callable, &[CallArg::Type(&rhs, range)], &[], range)
};
let lhs = self.expr_infer(&x.left);
let rhs = self.expr_infer(&x.right);
Expand Down Expand Up @@ -791,11 +764,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
);
self.call_infer(
callable,
&x.arguments.args,
&x.arguments.args.map(CallArg::Expr),
&x.arguments.keywords,
&|arg, hint| {
self.expr(arg, hint);
},
func_range,
)
})
Expand Down Expand Up @@ -934,7 +904,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
&Type::Tuple(Tuple::Concrete(elts)),
&dunder::GETITEM,
x.range,
slice::from_ref(&x.slice),
&[CallArg::Expr(&x.slice)],
&[],
),
}
Expand All @@ -944,15 +914,15 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
&Type::Tuple(Tuple::Unbounded(elt)),
&dunder::GETITEM,
x.range,
slice::from_ref(&x.slice),
&[CallArg::Expr(&x.slice)],
&[],
),
Type::Any(style) => style.propagate(),
Type::ClassType(_) => self.call_method(
&fun,
&dunder::GETITEM,
x.range,
slice::from_ref(&x.slice),
&[CallArg::Expr(&x.slice)],
&[],
),
t => self.error(
Expand Down

0 comments on commit 8a77318

Please sign in to comment.