Address torch.compile graph breaks in models #452

Open

akihironitta opened this issue Sep 18, 2024 · 0 comments

akihironitta (Member) commented Sep 18, 2024

Currently, models in torch_frame.nn have a number of graph breaks, but we should be able to remove all or most of them to maximise performance optimisation opportunities. Specifically, the goal is to address as many graph breaks as possible in this test case:

import pytest
import torch

from torch_frame import stype
from torch_frame.datasets import FakeDataset
from torch_frame.nn import (
    ExcelFormer,
    FTTransformer,
    ResNet,
    TabNet,
    TabTransformer,
    Trompt,
)


@pytest.mark.parametrize(
    "model_cls, model_kwargs, stypes, expected_graph_breaks",
    [
        pytest.param(
            FTTransformer,
            dict(channels=8),
            None,
            2,
            id="FTTransformer",
        ),
        pytest.param(ResNet, dict(channels=8), None, 2, id="ResNet"),
        pytest.param(
            TabNet,
            dict(
                split_feat_channels=2,
                split_attn_channels=2,
                gamma=0.1,
            ),
            None,
            7,
            id="TabNet",
        ),
        pytest.param(
            TabTransformer,
            dict(
                channels=8,
                num_heads=2,
                encoder_pad_size=2,
                attn_dropout=0.5,
                ffn_dropout=0.5,
            ),
            None,
            4,
            id="TabTransformer",
        ),
        pytest.param(
            Trompt,
            dict(channels=8, num_prompts=2),
            None,
            16,
            id="Trompt",
        ),
        pytest.param(
            ExcelFormer,
            dict(in_channels=8, num_cols=3, num_heads=1),
            [stype.numerical],
            4,
            id="ExcelFormer",
        ),
    ],
)
def test_compile_graph_break(
    model_cls,
    model_kwargs,
    stypes,
    expected_graph_breaks,
):
    # Keep Dynamo from raising on unsupported code so breaks can be counted.
    torch._dynamo.config.suppress_errors = True
    dataset = FakeDataset(
        num_rows=10,
        with_nan=False,
        stypes=stypes or [stype.categorical, stype.numerical],
    )
    dataset.materialize()
    tf = dataset.tensor_frame
    model = model_cls(
        out_channels=1,
        num_layers=2,
        col_stats=dataset.col_stats,
        col_names_dict=tf.col_names_dict,
        **model_kwargs,
    )
    # Trace the model with Dynamo and report how many graph breaks occur.
    explanation = torch._dynamo.explain(model)(tf)
    assert explanation.graph_break_count <= expected_graph_breaks
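
As a complementary way to chase down individual breaks (not part of the test above), compiling with fullgraph=True makes Dynamo raise at the first graph break with a stack trace pointing at the offending line, instead of silently splitting the graph. A minimal sketch, assuming the same FakeDataset/FTTransformer setup as in the test:

import torch

from torch_frame import stype
from torch_frame.datasets import FakeDataset
from torch_frame.nn import FTTransformer

# Build a tiny dataset and model, mirroring the test setup above.
dataset = FakeDataset(
    num_rows=10,
    with_nan=False,
    stypes=[stype.categorical, stype.numerical],
)
dataset.materialize()
tf = dataset.tensor_frame
model = FTTransformer(
    channels=8,
    out_channels=1,
    num_layers=2,
    col_stats=dataset.col_stats,
    col_names_dict=tf.col_names_dict,
)

# fullgraph=True turns the first graph break into an error, which makes the
# break easy to locate and fix one at a time.
compiled_model = torch.compile(model, fullgraph=True)
out = compiled_model(tf)

Since the models currently do break, this raises until the breaks are addressed, which also makes it a handy regression check once a model reaches zero breaks.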


Note

torch._dynamo.explain() doesn't show graph break reasons even when there are graph breaks. Instead, I suggest finding the graph break reasons with torch logs:

TORCH_LOGS=graph_breaks pytest test/nn/models/test_compile.py -k ExcelFormer
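
The same logging can also be switched on from Python rather than via the environment variable; the sketch below assumes the torch._logging.set_logs API available in recent PyTorch 2.x releases (TORCH_LOGS itself remains the documented route):

import torch
import torch._logging

# Roughly equivalent to TORCH_LOGS=graph_breaks: each graph break is logged
# together with its reason and the user stack trace that triggered it.
torch._logging.set_logs(graph_breaks=True)

# ...then run torch._dynamo.explain(model)(tf) or the test above and read
# the break reasons from the log output.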