diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index c3924b126cc7..7b65b51c13c6 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -2965,61 +2965,6 @@ def test_mixed_device_error_message(self):
         with self.assertRaisesRegex(ValueError, expected_error_message):
             flex_attention(query, key, value)
 
-    @supported_platform
-    def test_captured_wrong_device_error_message(self):
-        means = torch.randn(64, 3).cuda()
-        length_scales = torch.logspace(0.001, 0.1, 8)
-
-        def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
-            q_pos = means[q_idx]
-            k_pos = means[k_idx]
-            dist = (q_pos - k_pos).pow(2).sum(-1).sqrt()
-            scale = length_scales[h]
-            inv_dist = torch.exp(-dist / scale)
-            return inv_dist * score
-
-        expected_error_message = "Buffers cannot be created"
-
-        q, k, v = (torch.randn(1, 8, 64, 64, device="cuda") for _ in range(3))
-        with self.assertRaisesRegex(RuntimeError, expected_error_message):
-            torch.compile(flex_attention)(q, k, v, score_mod=euclidean_dist_pos_embed)
-
-    @supported_platform
-    def test_cant_lower_error_message(self):
-        # We can't lower a 256-element reduction inside a pointwise reduction
-        means = torch.randn(64, 256).cuda()
-        length_scales = torch.logspace(0.001, 0.1, 8).cuda()
-
-        def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
-            q_pos = means[q_idx]
-            k_pos = means[k_idx]
-            dist = (q_pos - k_pos).pow(2).sum(-1).sqrt()
-            scale = length_scales[h]
-            inv_dist = torch.exp(-dist / scale)
-            return inv_dist * score
-
-        expected_error_message = "Buffers cannot be created"
-
-        q, k, v = (torch.randn(1, 8, 64, 64, device="cuda") for _ in range(3))
-        with self.assertRaisesRegex(RuntimeError, expected_error_message):
-            torch.compile(flex_attention)(q, k, v, score_mod=euclidean_dist_pos_embed)
-
-    @supported_platform
-    def test_reduction_unrolled(self):
-        # We can't lower a 256-element reduction inside a pointwise reduction
-        means = torch.randn(S, 3).cuda()
-        length_scales = torch.logspace(0.001, 0.1, H).cuda()
-
-        def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
-            q_pos = means[q_idx]
-            k_pos = means[k_idx]
-            dist = (q_pos - k_pos).pow(2).sum(-1).sqrt()
-            scale = length_scales[h]
-            inv_dist = torch.exp(-dist / scale)
-            return inv_dist * score
-
-        self.run_test(euclidean_dist_pos_embed, torch.bfloat16)
-
     @supported_platform
     def test_invalid_block_size(self):
         # Create tensors on different devices
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index ca554b7cf43e..ddcd7462c704 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -443,10 +443,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs):
         "atol": 3e-4,
         "rtol": 0.002,
     },
-    ("nn.functional.triplet_margin_with_distance_loss", f16): {
-        "atol": 3e-4,
-        "rtol": 0.003,
-    },
     ("softmax", f16): {"atol": 1e-4, "rtol": 0.02},
     ("polygamma.polygamma_n_0", f32): {"atol": 1e-3, "rtol": 1e-4},
     ("polygamma.polygamma_n_1", f32): {"atol": 1e-3, "rtol": 1e-4},
diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py
index 25d9f2f1d88c..932d7440ab4c 100644
--- a/torch/_higher_order_ops/flex_attention.py
+++ b/torch/_higher_order_ops/flex_attention.py
@@ -410,8 +410,8 @@ def flex_attention_functionalize(
     assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
 
     example_vals = (
-        [query_unwrapped.new_zeros(())]
-        + [query_unwrapped.new_zeros((), dtype=torch.int) for _ in range(4)]
+        [torch.zeros((), dtype=query.dtype)]
+        + [torch.zeros((), dtype=torch.int) for _ in range(4)]
         + list(score_mod_other_buffers_unwrapped)
     )
     with ctx.redispatch_to_next() as m:
@@ -710,11 +710,11 @@ def flex_attention_autograd(
     input_requires_grad = any(t.requires_grad for t in (query, key, value))
     if torch.is_grad_enabled() and input_requires_grad:
         example_vals = (
-            query.new_zeros((), requires_grad=input_requires_grad),
-            query.new_zeros((), dtype=torch.int),
-            query.new_zeros((), dtype=torch.int),
-            query.new_zeros((), dtype=torch.int),
-            query.new_zeros((), dtype=torch.int),
+            torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad),
+            torch.zeros((), dtype=torch.int),
+            torch.zeros((), dtype=torch.int),
+            torch.zeros((), dtype=torch.int),
+            torch.zeros((), dtype=torch.int),
         )
         fw_graph, bw_graph = create_fw_bw_graph(
             score_mod, example_vals, score_mod_other_buffers
@@ -930,11 +930,11 @@ def trace_flex_attention_backward(
         mask_mod_other_buffers,
     )
 
-    fw_example_vals = [query.new_zeros((), requires_grad=query.requires_grad)] + [
-        query.new_zeros((), dtype=torch.int) for _ in range(4)
-    ]
-    bw_example_vals = fw_example_vals + [query.new_zeros(())]
-    mask_example_vals = [query.new_zeros((), dtype=torch.int) for _ in range(4)]
+    fw_example_vals = [
+        torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad)
+    ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
+    bw_example_vals = fw_example_vals + [torch.zeros((), dtype=query.dtype)]
+    mask_example_vals = [torch.zeros((), dtype=torch.int) for _ in range(4)]
     mask_graph = block_mask[-1]
     with TransformGetItemToIndex():
         fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *score_mod_other_buffers)
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 1ef697a22a38..8c42710318d0 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -1254,10 +1254,8 @@ def fn(index: int) -> OpsValue:
             isinstance(reduction_numel, Integer)
             and V.graph.sizevars.size_hint(reduction_numel)
             < config.unroll_reductions_threshold
-            and (sympy_product(ranges) != 1 or device.type == "cuda")
+            and sympy_product(ranges) != 1
         ):
-            # NB: This works around https://github.com/pytorch/pytorch/issues/140457
-            # since turning reductions into pointwise ops can exacerbate this problem
             return Pointwise.create(
                 device=device,
                 dtype=dst_dtype,
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 07622dec1093..7364415be6da 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -104,40 +104,59 @@ def build_subgraph_buffer(
         args: The args that are passed into the subgraph. Contains both fixed and lifted inputs.
         subgraph: The Subgraph ir for which to produce the output node
     """
-    from ..subgraph_lowering import PointwiseSubgraphLowering
-
-    pw_subgraph = PointwiseSubgraphLowering(
-        subgraph.graph_module, root_graph_lowering=V.graph
-    )
-    with V.set_graph_handler(pw_subgraph):  # type: ignore[arg-type]
-        pw_subgraph.run(*args)
-
-    def convert_output_node_to_buffer(output):
-        if output is None:
-            return None
-        output_buffer = output
-        assert isinstance(output_buffer, TensorBox), (
-            "The output node for flex attention's subgraph must be a TensorBox, but got: ",
-            type(output_buffer),
-        )
-        assert isinstance(output_buffer.data, StorageBox), (
-            "The output node for the flex attention subgraph must be a StorageBox, but got: ",
-            type(output_buffer),
-        )
-        subgraph_buffer = ComputedBuffer(
-            name=None,
-            layout=FlexibleLayout(
-                device=output_buffer.data.get_device(),
-                dtype=output_buffer.data.get_dtype(),
-                size=output_buffer.data.get_size(),
-            ),
-            data=output_buffer.data.data,  # type: ignore[arg-type]
-        )
-        return subgraph_buffer
-
-    # node.args[0] is either a single element or a list of elements
-    # representing all outputs of the function.
-    return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs)
+    cnt = 0
+    env = {}
+    for node in subgraph.graph_module.graph.nodes:
+        # There are two classes of placeholder inpts that we need
+        # to handle differently. For the first n_scalar_inps inputs
+        # we expect that these placeholders were generated by the make_fx call
+        # in the flex Attention HOP. So we need to create a new placeholder
+        # TensorBox for each of these inputs. For the rest of the inputs we
+        # expect that these are lifted inputs that fill up the '*other_buffers'
+        # tuple and already have corresponding TensorBoxes passed in as args.
+        with V.graph.set_current_node(node):
+            if node.op == "placeholder":
+                env[node] = args[cnt]
+                cnt += 1
+            elif node.op == "call_function":
+                # For call_function we use the default lowerings and pass in the
+                # already created TensorBoxes as args
+
+                args, kwargs = tree_map(
+                    lambda x: env[x] if x in env else x, (node.args, node.kwargs)
+                )
+                env[node] = lowerings[node.target](*args, **kwargs)
+            elif node.op == "output":
+
+                def convert_output_node_to_buffer(output):
+                    if output is None:
+                        return None
+                    output_node = output
+                    output_buffer = env[output_node]
+                    assert isinstance(output_buffer, TensorBox), (
+                        "The output node for flex attention's subgraph must be a TensorBox, but got: ",
+                        type(output_buffer),
+                    )
+                    assert isinstance(output_buffer.data, StorageBox), (
+                        "The output node for the flex attention subgraph must be a StorageBox, but got: ",
+                        type(output_buffer),
+                    )
+                    subgraph_buffer = ComputedBuffer(
+                        name=None,
+                        layout=FlexibleLayout(
+                            device=output_buffer.data.get_device(),
+                            dtype=output_buffer.data.get_dtype(),
+                            size=output_buffer.data.get_size(),
+                        ),
+                        data=output_buffer.data.data,  # type: ignore[arg-type]
+                    )
+                    return subgraph_buffer
+
+                # node.args[0] is either a single element or a list of elements
+                # representing all outputs of the function.
+                return tree_map(convert_output_node_to_buffer, node.args[0])
+
+    raise ValueError("FlexAttention was passed a subgraph with no output node!")
 
 
 # Inner Triton functions shared by flex_attention & split-k decoding kernels.
@@ -503,7 +522,7 @@ def forward_block_mn(
         ) | indent_except_first(2) }}
 
         if CHECK_BLOCK_BOUNDARY:
-            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
+            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, float("-inf"))
         # apply mask for partially unmasked blocks
         post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
 
@@ -1739,8 +1758,6 @@ def flex_attention_backward(*args, **kwargs):
     joint_placeholder_inps = fwd_placeholder_inps + [
         create_placeholder("grad_score_mod", dtype, device)
     ]
-    # Sometimes we have weird unused nodes here
-    joint_graph.graph_module.graph.eliminate_dead_code()
     joint_subgraph_buffer, *_ = build_subgraph_buffer(
         joint_placeholder_inps + list(score_mod_other_buffers), joint_graph
     )
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 1ca80369b562..2711ea2f4232 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -74,7 +74,7 @@
 
 log = logging.getLogger(__name__)
 
-lowerings: Dict[Callable[..., Any], Callable[..., Any]] = {}
+lowerings: Dict[torch._ops.OpOverload, Callable[..., Any]] = {}
 # Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints
 _maybe_layout_constraints: Dict[
     torch._ops.OpOverload, Optional[Callable[..., Any]]
diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py
index e0b36c419652..21145375ede8 100644
--- a/torch/_inductor/subgraph_lowering.py
+++ b/torch/_inductor/subgraph_lowering.py
@@ -13,6 +13,7 @@
 from . import ir
 from .exc import SubgraphLoweringException
 from .ops_handler import SimpleCSEHandler
+from .sizevars import SizeVarAllocator
 from .virtualized import ops, V, WrapperHandler
 
 
@@ -21,11 +22,6 @@
 
 
 class PointwiseSubgraphLowering(torch.fx.Interpreter):
-    """
-    Lowers a pointwise subgraph to a single set of buffers with a separate
-    lowering object. Errors if buffers are created unexpectedly
-    """
-
     graph_outputs: Optional[List[ir.IRNode]]
 
     def __init__(
@@ -37,19 +33,18 @@ def __init__(
         self.graph_outputs = None
         self.root_graph = root_graph_lowering
 
+    @property
+    def sizevars(self) -> SizeVarAllocator:
+        return self.root_graph.sizevars
+
     def mark_buffer_mutated(self, name: str) -> None:
         raise SubgraphLoweringException("Mutations are not supported in this context")
 
     def register_buffer(self, buffer: ir.Buffer) -> str:
         raise SubgraphLoweringException(
-            "Buffers cannot be created while lowering a pointwise subgraph. "
-            "This could be for a good reason (e.g. you're calling an op we can't codegen as a pointwise op), "
-            "but it could also be a bug. Please file a bug report if you think this should be supportable."
+            "Buffer creation is not supported in this context"
         )
 
-    def __getattr__(self, name: str) -> Any:
-        return getattr(self.root_graph, name)
-
     def call_function(
         self,
         target: Callable[[Any], Any],  # type: ignore[override]
@@ -61,11 +56,18 @@ def call_function(
         if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
             return super().call_function(target, args, kwargs)
 
+        assert isinstance(target, torch._ops.OpOverload)
+
        if target not in lowerings:
            raise SubgraphLoweringException(
                f"{target} not supported in subgraph, (missing lowering)"
            )
 
+        if torch.Tag.pointwise not in target.tags:
+            raise SubgraphLoweringException(
+                f"Only pointwise operators are supported in this context, but got {target}"
+            )
+
         return lowerings[target](*args, **kwargs)
 
     def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None:  # type: ignore[override]
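
For context, a minimal standalone sketch of the captured-buffer `score_mod` pattern that the tests removed above exercise. This snippet is illustrative only and not part of the patch; it assumes a CUDA device and the public `torch.nn.attention.flex_attention` entry point, while the names `means`, `length_scales`, and `euclidean_dist_pos_embed` are taken from the deleted test code.

```python
# Illustrative sketch (not part of the patch): a score_mod that indexes captured
# buffers, mirroring the deleted tests. Assumes a CUDA device is available.
import torch
from torch.nn.attention.flex_attention import flex_attention

means = torch.randn(64, 3, device="cuda")
length_scales = torch.logspace(0.001, 0.1, 8, device="cuda")


def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
    # Look up positions for the query/key indices and rescale the raw score
    # by an inverse-distance term with a per-head length scale.
    q_pos = means[q_idx]
    k_pos = means[k_idx]
    dist = (q_pos - k_pos).pow(2).sum(-1).sqrt()
    scale = length_scales[h]
    inv_dist = torch.exp(-dist / scale)
    return inv_dist * score


q, k, v = (torch.randn(1, 8, 64, 64, device="cuda") for _ in range(3))
out = torch.compile(flex_attention)(q, k, v, score_mod=euclidean_dist_pos_embed)
```

In the state this patch reverts, compiling this pattern raised the "Buffers cannot be created" lowering error, which is exactly what the deleted `test_captured_wrong_device_error_message` and `test_cant_lower_error_message` asserted.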