Skip to content

Commit

Permalink
Add batch decomposition for torch.linalg.eigh (pytorch#110640)
Browse files Browse the repository at this point in the history
  • Loading branch information
guilhermeleobas authored and pytorchmergebot committed Oct 9, 2023
1 parent 201d02e commit 0a580da
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/functorch/BatchRulesDecompositions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE(linalg_multi_dot);
OP_DECOMPOSE(linalg_norm);
OP_DECOMPOSE2(linalg_norm, ord_str);
OP_DECOMPOSE(linalg_eigh);
OP_DECOMPOSE(linalg_solve);
OP_DECOMPOSE(linalg_solve_ex);
OP_DECOMPOSE(linalg_svd);
Expand Down
4 changes: 2 additions & 2 deletions test/functorch/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3512,7 +3512,7 @@ def test():
xfail('sparse.mm', 'reduce'), # sparse
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable autograd.Function
skip('_softmax_backward_data'),
skip('linalg.eigh', ''), # not unique, see test_linalg_eigh for manual test
skip('linalg.eigh', ''), # not always return the same result for the same input, see test_linalg_eigh for manual test
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
# ----------------------------------------------------------------------

Expand Down Expand Up @@ -3838,7 +3838,7 @@ def compute_A(out):
assert len(opinfos) > 0

for op in opinfos:
self.opinfo_vmap_test(device, torch.float, op, check_has_batch_rule=False,
self.opinfo_vmap_test(device, torch.float, op, check_has_batch_rule=True,
postprocess_fn=compute_A)

def test_slogdet(self, device):
Expand Down
1 change: 0 additions & 1 deletion test/functorch/test_vmap_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@
"aten::less_equal_.Scalar",
"aten::less_equal_.Tensor",
"aten::linalg_cond.p_str",
"aten::linalg_eigh",
"aten::linalg_eigh.eigvals",
"aten::linalg_lu_factor",
"aten::linalg_matrix_rank",
Expand Down

0 comments on commit 0a580da

Please sign in to comment.