From 8a77318bee242b500c683ec86f917594335f62d6 Mon Sep 17 00:00:00 2001 From: "Rebecca Chen (Python)" Date: Fri, 13 Dec 2024 02:07:45 -0800 Subject: [PATCH] Un-generify call_infer and use a CallArg enum instead Summary: This lets us reduce call_method_generic, call_method, and call_method_with_types down to a single call_method function. It'll also allow me to prepend a "self" argument to an arguments list without worrying about whether the other arguments are expressions or types. Reviewed By: ndmitchell Differential Revision: D67188826 fbshipit-source-id: ba9dd6970bcd8c03975687ad7decfb3967694ec5 --- pyre2/pyre2/bin/alt/answers.rs | 55 ++++++++--------- pyre2/pyre2/bin/alt/expr.rs | 110 ++++++++++++--------------------- 2 files changed, 67 insertions(+), 98 deletions(-) diff --git a/pyre2/pyre2/bin/alt/answers.rs b/pyre2/pyre2/bin/alt/answers.rs index 7b40ef664d0..95dee94a9fd 100644 --- a/pyre2/pyre2/bin/alt/answers.rs +++ b/pyre2/pyre2/bin/alt/answers.rs @@ -21,7 +21,7 @@ use starlark_map::ordered_set::OrderedSet; use starlark_map::small_map::Entry; use starlark_map::small_map::SmallMap; -use crate::alt::expr::TypeCallArg; +use crate::alt::expr::CallArg; use crate::ast::Ast; use crate::binding::binding::Binding; use crate::binding::binding::BindingAnnotation; @@ -634,12 +634,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { match iterable { Type::ClassType(cls) => { let ty = if self.has_attribute(cls.class_object(), &dunder::ITER) { - let iterator_ty = - self.call_method_with_types(iterable, &dunder::ITER, range, &[]); - self.call_method_with_types(&iterator_ty, &dunder::NEXT, range, &[]) + let iterator_ty = self.call_method(iterable, &dunder::ITER, range, &[], &[]); + self.call_method(&iterator_ty, &dunder::NEXT, range, &[], &[]) } else if self.has_attribute(cls.class_object(), &dunder::GETITEM) { - let arg = TypeCallArg::new(self.stdlib.int().to_type(), range); - self.call_method_with_types(iterable, &dunder::GETITEM, range, &[arg]) + let int_ty = self.stdlib.int().to_type(); + let arg = CallArg::Type(&int_ty, range); + self.call_method(iterable, &dunder::GETITEM, range, &[arg], &[]) } else { self.error(range, format!("Class `{}` is not iterable", cls.name())) }; @@ -774,13 +774,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ) -> Type { match kind { ContextManagerKind::Sync => { - self.call_method_with_types(context_manager_type, &dunder::ENTER, range, &[]) + self.call_method(context_manager_type, &dunder::ENTER, range, &[], &[]) } - ContextManagerKind::Async => match self.unwrap_awaitable(&self.call_method_with_types( + ContextManagerKind::Async => match self.unwrap_awaitable(&self.call_method( context_manager_type, &dunder::AENTER, range, &[], + &[], )) { Some(ty) => ty, None => self.error(range, format!("Expected `{}` to be async", dunder::AENTER)), @@ -795,32 +796,28 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { range: TextRange, ) -> Type { let base_exception_class_type = Type::type_form(self.stdlib.base_exception().to_type()); + let arg1 = Type::Union(vec![base_exception_class_type, Type::None]); + let arg2 = Type::Union(vec![self.stdlib.base_exception().to_type(), Type::None]); + let arg3 = Type::Union(vec![self.stdlib.traceback_type().to_type(), Type::None]); let exit_arg_types = [ - TypeCallArg::new( - Type::Union(vec![base_exception_class_type, Type::None]), - range, - ), - TypeCallArg::new( - Type::Union(vec![self.stdlib.base_exception().to_type(), Type::None]), - range, - ), - TypeCallArg::new( - Type::Union(vec![self.stdlib.traceback_type().to_type(), Type::None]), - range, - ), + CallArg::Type(&arg1, range), + CallArg::Type(&arg2, range), + CallArg::Type(&arg3, range), ]; match kind { - ContextManagerKind::Sync => self.call_method_with_types( + ContextManagerKind::Sync => self.call_method( context_manager_type, &dunder::EXIT, range, &exit_arg_types, + &[], ), - ContextManagerKind::Async => match self.unwrap_awaitable(&self.call_method_with_types( + ContextManagerKind::Async => match self.unwrap_awaitable(&self.call_method( context_manager_type, &dunder::AEXIT, range, &exit_arg_types, + &[], )) { Some(ty) => ty, None => self.error(range, format!("Expected `{}` to be async", dunder::AEXIT)), @@ -923,7 +920,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &base, &inplace_dunder(x.op), x.range, - &[*x.value.clone()], + &[CallArg::Expr(&x.value)], &[], ) } @@ -950,15 +947,16 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let base = self.expr(&x.value, None); let slice_ty = self.expr(&x.slice, None); let value_ty = self.solve_binding_inner(b); - self.call_method_with_types( + self.call_method( &base, &dunder::SETITEM, x.range, &[ - TypeCallArg::new(slice_ty, x.slice.range()), + CallArg::Type(&slice_ty, x.slice.range()), // use the subscript's location - TypeCallArg::new(value_ty, x.range), + CallArg::Type(&value_ty, x.range), ], + &[], ) } Binding::UnpackedValue(b, range, pos) => { @@ -1237,12 +1235,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { // TODO: check against duplicate keys (optional) let key_ty = self.expr(mapping_key, None); let binding_ty = self.get_idx(*binding_key).arc_clone(); - let arg = TypeCallArg::new(key_ty, mapping_key.range()); - self.call_method_with_types( + let arg = CallArg::Type(&key_ty, mapping_key.range()); + self.call_method( &binding_ty, &dunder::GETITEM, mapping_key.range(), &[arg], + &[], ) } Binding::PatternMatchClassPositional(_, idx, key, range) => { diff --git a/pyre2/pyre2/bin/alt/expr.rs b/pyre2/pyre2/bin/alt/expr.rs index db3628253b6..b1940121efd 100644 --- a/pyre2/pyre2/bin/alt/expr.rs +++ b/pyre2/pyre2/bin/alt/expr.rs @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. */ -use std::slice; - use dupe::Dupe; use ruff_python_ast::name::Name; use ruff_python_ast::Arguments; @@ -55,22 +53,40 @@ enum CallStyle<'a> { FreeForm, } -/// This struct bundles a `Type` with a `TextRange`, allowing us to typecheck function calls -/// when we only know the types of the arguments but not the original expressions. -pub struct TypeCallArg { - ty: Type, - range: TextRange, +#[derive(Clone)] +pub enum CallArg<'a> { + /// Bundles a `Type` with a `TextRange`, allowing us to typecheck function calls + /// when we only know the types of the arguments but not the original expressions. + Type(&'a Type, TextRange), + Expr(&'a Expr), } -impl Ranged for TypeCallArg { +impl Ranged for CallArg<'_> { fn range(&self) -> TextRange { - self.range + match self { + Self::Type(_, r) => *r, + Self::Expr(e) => e.range(), + } } } -impl TypeCallArg { - pub fn new(ty: Type, range: TextRange) -> Self { - Self { ty, range } +impl CallArg<'_> { + /// Check an argument against the type hint (if any) on the corresponding parameter + fn check_against_hint( + &self, + answers: &AnswersSolver, + hint: Option<&Type>, + ) { + match self { + Self::Type(ty, r) => { + if let Some(hint) = hint { + answers.check_type(hint, ty, *r); + } + } + Self::Expr(e) => { + answers.expr(e, hint); + } + } } } @@ -192,15 +208,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } - fn call_method_generic( + pub fn call_method( &self, ty: &Type, method_name: &Name, range: TextRange, - args: &[T], + args: &[CallArg], keywords: &[Keyword], - // Function to check an argument against the type hint (if any) on the corresponding parameter - check_arg: &dyn Fn(&T, Option<&Type>), ) -> Type { self.distribute_over_union(ty, |ty| { let callable = match self @@ -210,39 +224,10 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { Ok(ty) => self.as_call_target_or_error(ty, CallStyle::Method(method_name), range), Err(msg) => self.error_call_target(range, msg), }; - self.call_infer(callable, args, keywords, check_arg, range) + self.call_infer(callable, args, keywords, range) }) } - pub fn call_method( - &self, - ty: &Type, - method_name: &Name, - range: TextRange, - args: &[Expr], - keywords: &[Keyword], - ) -> Type { - let check_arg = &|arg: &Expr, hint: Option<&Type>| { - self.expr(arg, hint); - }; - self.call_method_generic(ty, method_name, range, args, keywords, check_arg) - } - - pub fn call_method_with_types( - &self, - ty: &Type, - method_name: &Name, - range: TextRange, - args: &[TypeCallArg], - ) -> Type { - let check_arg = &|arg: &TypeCallArg, hint: Option<&Type>| { - if let Some(hint) = hint { - self.check_type(hint, &arg.ty, arg.range); - } - }; - self.call_method_generic(ty, method_name, range, args, &[], check_arg) - } - pub fn expr(&self, x: &Expr, check: Option<&Type>) -> Type { match check { Some(want) if !want.is_any() => { @@ -263,12 +248,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } - fn call_infer( + fn call_infer( &self, call_target: CallTarget, - args: &[T], + args: &[CallArg], keywords: &[Keyword], - check_arg: &dyn Fn(&T, Option<&Type>), range: TextRange, ) -> Type { match call_target { @@ -281,7 +265,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ), args, keywords, - check_arg, range, ); cls.self_type() @@ -328,7 +311,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ); num_positional = -1; } - check_arg(arg, hint); + arg.check_against_hint(self, hint); } let mut need_positional = 0; let mut kwparams = SmallMap::new(); @@ -408,7 +391,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { Args::Ellipsis => { // Deal with Callable[..., R] for t in args { - check_arg(t, None); + t.check_against_hint(self, None); } callable.ret } @@ -435,17 +418,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let method_type = self.attr_infer(lhs, &Name::new(op.dunder()), range); let callable = self.as_call_target_or_error(method_type, CallStyle::BinaryOp(op), range); - self.call_infer( - callable, - &[TypeCallArg::new(rhs, range)], - &[], - &|arg, hint| { - if let Some(hint) = hint { - self.check_type(hint, &arg.ty, arg.range); - } - }, - range, - ) + self.call_infer(callable, &[CallArg::Type(&rhs, range)], &[], range) }; let lhs = self.expr_infer(&x.left); let rhs = self.expr_infer(&x.right); @@ -791,11 +764,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ); self.call_infer( callable, - &x.arguments.args, + &x.arguments.args.map(CallArg::Expr), &x.arguments.keywords, - &|arg, hint| { - self.expr(arg, hint); - }, func_range, ) }) @@ -934,7 +904,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &Type::Tuple(Tuple::Concrete(elts)), &dunder::GETITEM, x.range, - slice::from_ref(&x.slice), + &[CallArg::Expr(&x.slice)], &[], ), } @@ -944,7 +914,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &Type::Tuple(Tuple::Unbounded(elt)), &dunder::GETITEM, x.range, - slice::from_ref(&x.slice), + &[CallArg::Expr(&x.slice)], &[], ), Type::Any(style) => style.propagate(), @@ -952,7 +922,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &fun, &dunder::GETITEM, x.range, - slice::from_ref(&x.slice), + &[CallArg::Expr(&x.slice)], &[], ), t => self.error(