Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support of scipy.linalg.solve_banded() #607

Merged
merged 4 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions autograd/scipy/linalg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import division
from functools import partial
import scipy.linalg

import autograd.numpy as anp
Expand Down Expand Up @@ -35,6 +36,62 @@ def vjp(g):
lambda ans, a, b, trans=0, lower=False, **kwargs:
lambda g: solve_triangular(a, g, trans=_flip(a, trans), lower=lower))

def grad_solve_banded(argnum, ans, l_and_u, a, b):
    """Build the vector-Jacobian product of ``solve_banded`` for one argument.

    Parameters
    ----------
    argnum : int
        1 selects the banded matrix ``a``; 2 selects the right-hand side ``b``.
    ans : array
        The forward result ``solve_banded(l_and_u, a, b)``.
    l_and_u : pair of ints
        Number of lower and upper diagonals stored in ``a``.
    a : array
        Banded matrix, one row per diagonal: upper diagonals first, then the
        main diagonal, then the lower diagonals.
    b : array
        Right-hand side. Unused here; the gradients only need ``ans``.

    Returns
    -------
    callable
        Maps the cotangent ``g`` of ``ans`` to the cotangent of the selected
        argument. Right-hand sides with one column (1-D) or several columns
        (2-D) are both supported.

    Raises
    ------
    ValueError
        If ``argnum`` is neither 1 nor 2.
    """

    def as_columns(x):
        # View a single right-hand side as a one-column matrix so that the
        # 1-D and 2-D cases share the same code path.
        return x if x.ndim == a.ndim else x[..., None]

    def transpose_banded(l_and_u, a):
        # Compute the transpose of a banded matrix.
        # The transpose is itself a banded matrix with the lower and upper
        # band counts swapped.
        num_lower, num_upper = l_and_u

        # Row rr of `a` stores diagonal number (num_upper - rr).  In the
        # transpose that diagonal lands on the opposite side of the main
        # diagonal, so its padding has to move to the other end of the row
        # (a cyclic shift by rr - num_upper) and the row order is reversed.
        shifted = [anp.roll(a[rr], rr - num_upper)
                   for rr in range(a.shape[0])]
        return (num_upper, num_lower), anp.vstack(shifted[::-1])

    def banded_dot(l_and_u, uu, vv):
        # Compute sum_k outer(uu[:, k], vv[:, k]), restricted to the bands
        # specified by l_and_u and laid out like `a` (one row per diagonal).
        num_lower, num_upper = l_and_u
        size = uu.shape[0]

        rows = []

        # super-diagonals, outermost first; padding sits at the front
        for kk in range(num_upper, 0, -1):
            rows.append(anp.hstack([anp.zeros(min(kk, size)),
                                    anp.sum(uu[:-kk] * vv[kk:], axis=1)]))

        # main diagonal
        rows.append(anp.sum(uu * vv, axis=1))

        # sub-diagonals, innermost first; padding sits at the back
        for kk in range(1, num_lower + 1):
            rows.append(anp.hstack([anp.sum(uu[kk:] * vv[:-kk], axis=1),
                                    anp.zeros(min(kk, size))]))

        # vstack keeps the result 2-D even when only the main diagonal exists
        return anp.vstack(rows)

    T_l_and_u, T_a = transpose_banded(l_and_u, a)

    if argnum == 1:
        def vjp_matrix(g):
            # d(ans) = -A^{-1} dA ans, hence the cotangent of A is
            # -(A^{-T} g) ans^T, kept only on the stored bands.
            adjoint = as_columns(solve_banded(T_l_and_u, T_a, g))
            return -banded_dot(l_and_u, adjoint, as_columns(ans))
        return vjp_matrix
    elif argnum == 2:
        # The cotangent of b is A^{-T} g: a banded solve with the transpose.
        return lambda g: solve_banded(T_l_and_u, T_a, g)

    raise ValueError("grad_solve_banded supports argnum 1 or 2, got %r" % (argnum,))

# Register the reverse-mode gradients of solve_banded: one VJP maker for the
# banded matrix (argument 1) and one for the right-hand side (argument 2).
defvjp(solve_banded,
       *[partial(grad_solve_banded, which) for which in (1, 2)],
       argnums=[1, 2])

def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64):
assert disp, "sqrtm jvp not implemented for disp=False"
return solve_sylvester(ans, ans, dA)
Expand Down
1 change: 1 addition & 0 deletions tests/test_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,4 @@ def test_odeint():
def test_sqrtm_fwd():
    """Forward-mode, second-order combo_check of spla.sqrtm on an R(3, 3) input.

    Named distinctly from ``test_sqrtm``: the definition that follows in this
    file uses that name, and two functions sharing it would leave only the
    later one visible to the test collector.
    """
    combo_check(spla.sqrtm, modes=['fwd'], order=2)([R(3, 3)])
def test_sqrtm():
    """combo_check sqrtm with its matrix argument symmetrized, in both modes."""
    symmetric_sqrtm = symmetrize_matrix_arg(spla.sqrtm, 0)
    run_check = combo_check(symmetric_sqrtm, modes=['fwd', 'rev'], order=2)
    run_check([R(3, 3)])
def test_solve_sylvester():
    """combo_check solve_sylvester w.r.t. all three arguments, in both modes."""
    run_check = combo_check(spla.solve_sylvester, [0, 1, 2],
                            modes=['rev', 'fwd'], order=2)
    run_check([R(3, 3)], [R(3, 3)], [R(3, 3)])
def test_solve_banded():
    """Reverse-mode combo_check of solve_banded w.r.t. arguments 1 and 2.

    Uses l_and_u = (1, 1) with an R(3, 5) banded matrix and an R(5)
    right-hand side.
    """
    run_check = combo_check(spla.solve_banded, [1, 2], modes=['rev'], order=1)
    run_check([(1, 1)], [R(3,5)], [R(5)])
Loading