Revert "Made FlexAttention error on subgraph lowering failure (#140331)"
pytorchmergebot committed Nov 15, 2024
1 parent 55f1959 commit de34f58
Showing 7 changed files with 81 additions and 123 deletions.
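For context, the tests deleted below construct score_mod callbacks that index into captured tensors and reduce over an embedding dimension, which makes the score_mod subgraph non-pointwise. A minimal sketch of that usage pattern (mirroring the deleted test code; assumes a CUDA device and the public flex_attention import path):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

means = torch.randn(64, 256, device="cuda")                   # captured positional embeddings
length_scales = torch.logspace(0.001, 0.1, 8, device="cuda")  # one scale per head

def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
    # The .sum(-1) is a 256-element reduction, so this score_mod is not
    # purely pointwise; under #140331 compiling it raised the
    # "Buffers cannot be created" error the deleted tests check for.
    dist = (means[q_idx] - means[k_idx]).pow(2).sum(-1).sqrt()
    return torch.exp(-dist / length_scales[h]) * score

q, k, v = (torch.randn(1, 8, 64, 64, device="cuda") for _ in range(3))
out = torch.compile(flex_attention)(q, k, v, score_mod=euclidean_dist_pos_embed)
```

This revert removes those error-path tests and restores the older subgraph-lowering behavior shown in the torch/_inductor changes below.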
55 changes: 0 additions & 55 deletions test/inductor/test_flex_attention.py
@@ -2965,61 +2965,6 @@ def test_mixed_device_error_message(self):
with self.assertRaisesRegex(ValueError, expected_error_message):
flex_attention(query, key, value)

@supported_platform
def test_captured_wrong_device_error_message(self):
means = torch.randn(64, 3).cuda()
length_scales = torch.logspace(0.001, 0.1, 8)

def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
q_pos = means[q_idx]
k_pos = means[k_idx]
dist = (q_pos - k_pos).pow(2).sum(-1).sqrt()
scale = length_scales[h]
inv_dist = torch.exp(-dist / scale)
return inv_dist * score

expected_error_message = "Buffers cannot be created"

q, k, v = (torch.randn(1, 8, 64, 64, device="cuda") for _ in range(3))
with self.assertRaisesRegex(RuntimeError, expected_error_message):
torch.compile(flex_attention)(q, k, v, score_mod=euclidean_dist_pos_embed)

@supported_platform
def test_cant_lower_error_message(self):
# We can't lower a 256-element reduction inside the pointwise score_mod subgraph
means = torch.randn(64, 256).cuda()
length_scales = torch.logspace(0.001, 0.1, 8).cuda()

def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
q_pos = means[q_idx]
k_pos = means[k_idx]
dist = (q_pos - k_pos).pow(2).sum(-1).sqrt()
scale = length_scales[h]
inv_dist = torch.exp(-dist / scale)
return inv_dist * score

expected_error_message = "Buffers cannot be created"

q, k, v = (torch.randn(1, 8, 64, 64, device="cuda") for _ in range(3))
with self.assertRaisesRegex(RuntimeError, expected_error_message):
torch.compile(flex_attention)(q, k, v, score_mod=euclidean_dist_pos_embed)

@supported_platform
def test_reduction_unrolled(self):
# A 3-element reduction inside the score_mod is small enough to be unrolled into pointwise ops
means = torch.randn(S, 3).cuda()
length_scales = torch.logspace(0.001, 0.1, H).cuda()

def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
q_pos = means[q_idx]
k_pos = means[k_idx]
dist = (q_pos - k_pos).pow(2).sum(-1).sqrt()
scale = length_scales[h]
inv_dist = torch.exp(-dist / scale)
return inv_dist * score

self.run_test(euclidean_dist_pos_embed, torch.bfloat16)

@supported_platform
def test_invalid_block_size(self):
# Create tensors on different devices
4 changes: 0 additions & 4 deletions test/inductor/test_torchinductor_opinfo.py
@@ -443,10 +443,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs):
"atol": 3e-4,
"rtol": 0.002,
},
("nn.functional.triplet_margin_with_distance_loss", f16): {
"atol": 3e-4,
"rtol": 0.003,
},
("softmax", f16): {"atol": 1e-4, "rtol": 0.02},
("polygamma.polygamma_n_0", f32): {"atol": 1e-3, "rtol": 1e-4},
("polygamma.polygamma_n_1", f32): {"atol": 1e-3, "rtol": 1e-4},
24 changes: 12 additions & 12 deletions torch/_higher_order_ops/flex_attention.py
@@ -410,8 +410,8 @@ def flex_attention_functionalize(
assert isinstance(mask_mod_other_buffers_unwrapped, tuple)

example_vals = (
[query_unwrapped.new_zeros(())]
+ [query_unwrapped.new_zeros((), dtype=torch.int) for _ in range(4)]
[torch.zeros((), dtype=query.dtype)]
+ [torch.zeros((), dtype=torch.int) for _ in range(4)]
+ list(score_mod_other_buffers_unwrapped)
)
with ctx.redispatch_to_next() as m:
@@ -710,11 +710,11 @@ def flex_attention_autograd(
input_requires_grad = any(t.requires_grad for t in (query, key, value))
if torch.is_grad_enabled() and input_requires_grad:
example_vals = (
query.new_zeros((), requires_grad=input_requires_grad),
query.new_zeros((), dtype=torch.int),
query.new_zeros((), dtype=torch.int),
query.new_zeros((), dtype=torch.int),
query.new_zeros((), dtype=torch.int),
torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad),
torch.zeros((), dtype=torch.int),
torch.zeros((), dtype=torch.int),
torch.zeros((), dtype=torch.int),
torch.zeros((), dtype=torch.int),
)
fw_graph, bw_graph = create_fw_bw_graph(
score_mod, example_vals, score_mod_other_buffers
@@ -930,11 +930,11 @@ def trace_flex_attention_backward(
mask_mod_other_buffers,
)

fw_example_vals = [query.new_zeros((), requires_grad=query.requires_grad)] + [
query.new_zeros((), dtype=torch.int) for _ in range(4)
]
bw_example_vals = fw_example_vals + [query.new_zeros(())]
mask_example_vals = [query.new_zeros((), dtype=torch.int) for _ in range(4)]
fw_example_vals = [
torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad)
] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
bw_example_vals = fw_example_vals + [torch.zeros((), dtype=query.dtype)]
mask_example_vals = [torch.zeros((), dtype=torch.int) for _ in range(4)]
mask_graph = block_mask[-1]
with TransformGetItemToIndex():
fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *score_mod_other_buffers)
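The hunks above swap query.new_zeros(...) back to torch.zeros(...) for the scalar example values used to trace the score_mod/mask_mod graphs. The practical difference is the device of the traced examples; a small illustration (not part of the diff, assumes CUDA is available):

```python
import torch

q = torch.randn(1, 8, 64, 64, device="cuda", dtype=torch.float16)

a = q.new_zeros(())                 # 0-dim tensor on q's device with q's dtype
b = torch.zeros((), dtype=q.dtype)  # 0-dim tensor on the default (CPU) device

print(a.device, a.dtype)  # cuda:0 torch.float16
print(b.device, b.dtype)  # cpu torch.float16
```

So after the revert, the subgraphs are traced with CPU example scalars again, matching the pre-#140331 behavior.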
4 changes: 1 addition & 3 deletions torch/_inductor/ir.py
@@ -1254,10 +1254,8 @@ def fn(index: int) -> OpsValue:
isinstance(reduction_numel, Integer)
and V.graph.sizevars.size_hint(reduction_numel)
< config.unroll_reductions_threshold
and (sympy_product(ranges) != 1 or device.type == "cuda")
and sympy_product(ranges) != 1
):
# NB: This works around https://github.com/pytorch/pytorch/issues/140457
# since turning reductions into pointwise ops can exacerbate this problem
return Pointwise.create(
device=device,
dtype=dst_dtype,
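The condition above decides when Inductor unrolls a small reduction into a chain of pointwise ops (via Pointwise.create) rather than emitting a true reduction. A hand-written illustration of what "unrolling" means here (not Inductor code; the threshold name is taken from the diff):

```python
import torch

x = torch.randn(128, 3)

# As a reduction: one sum over the size-3 trailing axis.
reduced = x.sum(dim=-1)

# "Unrolled" into pointwise ops: the same values computed with elementwise adds,
# which is the shape of code Inductor generates when reduction_numel is below
# config.unroll_reductions_threshold.
unrolled = x[:, 0] + x[:, 1] + x[:, 2]

torch.testing.assert_close(reduced, unrolled)
```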
91 changes: 54 additions & 37 deletions torch/_inductor/kernel/flex_attention.py
@@ -104,40 +104,59 @@ def build_subgraph_buffer(
args: The args that are passed into the subgraph. Contains both fixed and lifted inputs.
subgraph: The Subgraph ir for which to produce the output node
"""
from ..subgraph_lowering import PointwiseSubgraphLowering

pw_subgraph = PointwiseSubgraphLowering(
subgraph.graph_module, root_graph_lowering=V.graph
)
with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type]
pw_subgraph.run(*args)

def convert_output_node_to_buffer(output):
if output is None:
return None
output_buffer = output
assert isinstance(output_buffer, TensorBox), (
"The output node for flex attention's subgraph must be a TensorBox, but got: ",
type(output_buffer),
)
assert isinstance(output_buffer.data, StorageBox), (
"The output node for the flex attention subgraph must be a StorageBox, but got: ",
type(output_buffer),
)
subgraph_buffer = ComputedBuffer(
name=None,
layout=FlexibleLayout(
device=output_buffer.data.get_device(),
dtype=output_buffer.data.get_dtype(),
size=output_buffer.data.get_size(),
),
data=output_buffer.data.data, # type: ignore[arg-type]
)
return subgraph_buffer

# node.args[0] is either a single element or a list of elements
# representing all outputs of the function.
return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs)
cnt = 0
env = {}
for node in subgraph.graph_module.graph.nodes:
# There are two classes of placeholder inputs that we need
# to handle differently. For the first n_scalar_inps inputs
# we expect that these placeholders were generated by the make_fx call
# in the flex Attention HOP. So we need to create a new placeholder
# TensorBox for each of these inputs. For the rest of the inputs we
# expect that these are lifted inputs that fill up the '*other_buffers'
# tuple and already have corresponding TensorBoxes passed in as args.
with V.graph.set_current_node(node):
if node.op == "placeholder":
env[node] = args[cnt]
cnt += 1
elif node.op == "call_function":
# For call_function we use the default lowerings and pass in the
# already created TensorBoxes as args

args, kwargs = tree_map(
lambda x: env[x] if x in env else x, (node.args, node.kwargs)
)
env[node] = lowerings[node.target](*args, **kwargs)
elif node.op == "output":

def convert_output_node_to_buffer(output):
if output is None:
return None
output_node = output
output_buffer = env[output_node]
assert isinstance(output_buffer, TensorBox), (
"The output node for flex attention's subgraph must be a TensorBox, but got: ",
type(output_buffer),
)
assert isinstance(output_buffer.data, StorageBox), (
"The output node for the flex attention subgraph must be a StorageBox, but got: ",
type(output_buffer),
)
subgraph_buffer = ComputedBuffer(
name=None,
layout=FlexibleLayout(
device=output_buffer.data.get_device(),
dtype=output_buffer.data.get_dtype(),
size=output_buffer.data.get_size(),
),
data=output_buffer.data.data, # type: ignore[arg-type]
)
return subgraph_buffer

# node.args[0] is either a single element or a list of elements
# representing all outputs of the function.
return tree_map(convert_output_node_to_buffer, node.args[0])

raise ValueError("FlexAttention was passed a subgraph with no output node!")


# Inner Triton functions shared by flex_attention & split-k decoding kernels.
@@ -503,7 +522,7 @@ def forward_block_mn(
) | indent_except_first(2) }}
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, float("-inf"))
# apply mask for partially unmasked blocks
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
@@ -1739,8 +1758,6 @@ def flex_attention_backward(*args, **kwargs):
joint_placeholder_inps = fwd_placeholder_inps + [
create_placeholder("grad_score_mod", dtype, device)
]
# Sometimes we have weird unused nodes here
joint_graph.graph_module.graph.eliminate_dead_code()
joint_subgraph_buffer, *_ = build_subgraph_buffer(
joint_placeholder_inps + list(score_mod_other_buffers), joint_graph
)
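The last hunk above drops the eliminate_dead_code() call on the joint (backward) score_mod graph. That call is the standard torch.fx graph API; a standalone illustration of what it does, unrelated to flex attention itself:

```python
import torch
import torch.fx

def f(x):
    unused = x * 2  # traced as a node with no users
    return x + 1

gm = torch.fx.symbolic_trace(f)
gm.graph.eliminate_dead_code()  # removes the unused mul node
gm.recompile()
print(gm.code)                  # only the add remains
```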
2 changes: 1 addition & 1 deletion torch/_inductor/lowering.py
@@ -74,7 +74,7 @@


log = logging.getLogger(__name__)
lowerings: Dict[Callable[..., Any], Callable[..., Any]] = {}
lowerings: Dict[torch._ops.OpOverload, Callable[..., Any]] = {}
# Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints
_maybe_layout_constraints: Dict[
torch._ops.OpOverload, Optional[Callable[..., Any]]
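The annotation change restores torch._ops.OpOverload as the key type of Inductor's lowering table. A quick way to poke at the table (internal, unstable API; a sketch under the assumption that importing the module registers the default lowerings):

```python
import torch
from torch._inductor.lowering import lowerings

# Keys are aten OpOverloads, values are the lowering callables.
add_lowering = lowerings.get(torch.ops.aten.add.Tensor)
print(add_lowering is not None, type(add_lowering))
```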
24 changes: 13 additions & 11 deletions torch/_inductor/subgraph_lowering.py
@@ -13,6 +13,7 @@
from . import ir
from .exc import SubgraphLoweringException
from .ops_handler import SimpleCSEHandler
from .sizevars import SizeVarAllocator
from .virtualized import ops, V, WrapperHandler


@@ -21,11 +22,6 @@


class PointwiseSubgraphLowering(torch.fx.Interpreter):
"""
Lowers a pointwise subgraph to a single set of buffers with a separate
lowering object. Errors if buffers are created unexpectedly
"""

graph_outputs: Optional[List[ir.IRNode]]

def __init__(
@@ -37,19 +33,18 @@ def __init__(
self.graph_outputs = None
self.root_graph = root_graph_lowering

@property
def sizevars(self) -> SizeVarAllocator:
return self.root_graph.sizevars

def mark_buffer_mutated(self, name: str) -> None:
raise SubgraphLoweringException("Mutations are not supported in this context")

def register_buffer(self, buffer: ir.Buffer) -> str:
raise SubgraphLoweringException(
"Buffers cannot be created while lowering a pointwise subgraph. "
"This could be for a good reason (e.g. you're calling an op we can't codegen as a pointwise op), "
"but it could also be a bug. Please file a bug report if you think this should be supportable."
"Buffer creation is not supported in this context"
)

def __getattr__(self, name: str) -> Any:
return getattr(self.root_graph, name)

def call_function(
self,
target: Callable[[Any], Any], # type: ignore[override]
@@ -61,11 +56,18 @@ def call_function(
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
return super().call_function(target, args, kwargs)

assert isinstance(target, torch._ops.OpOverload)

if target not in lowerings:
raise SubgraphLoweringException(
f"{target} not supported in subgraph, (missing lowering)"
)

if torch.Tag.pointwise not in target.tags:
raise SubgraphLoweringException(
f"Only pointwise operators are supported in this context, but got {target}"
)

return lowerings[target](*args, **kwargs)

def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None: # type: ignore[override]
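The restored call_function check relies on operator tags declared in native_functions.yaml: only ops carrying torch.Tag.pointwise are allowed in the pointwise subgraph. A minimal sketch of that tag query (op choices are illustrative):

```python
import torch

# Elementwise ops carry the pointwise tag; reductions do not, so the
# restored check rejects them inside a pointwise subgraph.
print(torch.Tag.pointwise in torch.ops.aten.mul.Tensor.tags)       # expected: True
print(torch.Tag.pointwise in torch.ops.aten.sum.dim_IntList.tags)  # expected: False
```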
