
Scan operation causing significant overhead #23

Open
mcnaughtonadm opened this issue May 22, 2024 · 8 comments
Labels
bug Something isn't working

Comments

@mcnaughtonadm
Collaborator

When applying the linlog model in pymc inference, we face quite significant inference times (2 to 10 days). This is obviously not ideal for any productive application of the workflow. Using pymc's built-in profiling methods, we can narrow the issue down to the scan operation occurring in emll, specifically in:

```python
def steady_state_pytensor(self, Ex, Ey=None, en=None, yn=None, method="scan"):
    """Calculate the steady-state transformed metabolite concentrations
    and fluxes using PyTensor.

    Ex, Ey, en and yn should be pytensor matrices

    solver: function
        A function to solve Ax = b for a (possibly) singular A. Should
        accept pytensor matrices A and b, and return a symbolic x.
    """
    if Ey is not None:
        Ey = at.as_tensor_variable(Ey)

    if isinstance(en, np.ndarray):
        en = np.atleast_2d(en)
        n_exp = en.shape[0]
    else:
        n_exp = en.shape.eval()[0]

    if isinstance(yn, np.ndarray):
        yn = np.atleast_2d(yn)

    en = at.as_tensor_variable(en)
    yn = at.as_tensor_variable(yn)

    e_diag = en.dimshuffle(0, 1, "x") * np.diag(self.v_star)
    N_rep = self.Nr.reshape((-1, *self.Nr.shape)).repeat(n_exp, axis=0)
    N_hat = at.batched_dot(N_rep, e_diag)

    inner_v = Ey.dot(yn.T).T + np.ones(self.nr, dtype=_floatX)

    # One linear system A_i x_i = b_i per experiment.
    As = at.dot(N_hat, Ex)
    bs = at.batched_dot(-N_hat, inner_v.dimshuffle(0, 1, "x"))

    if method == "scan":
        xn, _ = pytensor.scan(
            lambda A, b: self.solve_pytensor(A, b), sequences=[As, bs], strict=True
        )
    else:
        xn_list = [None] * n_exp
        for i in range(n_exp):
            xn_list[i] = self.solve_pytensor(As[i], bs[i])
        xn = at.stack(xn_list)

    vn = en * (np.ones(self.nr) + at.dot(Ex, xn.T).T + at.dot(Ey, yn.T).T)

    return xn, vn
```

This method contains the block

```python
    if method == "scan":
        xn, _ = pytensor.scan(
            lambda A, b: self.solve_pytensor(A, b), sequences=[As, bs], strict=True
        )
```

which is where scan is used by our code. We need to find a way to optimize this method of determining xn, because it may have worked with Theano, but it is definitely struggling with PyTensor.

@djinnome
Collaborator

djinnome commented May 22, 2024

Instead of solve_pytensor, should we be using the Cholesky solve? I thought that was the secret sauce.

```python
def solve(self, A, b):
    A_hat = A.T @ A + self.lambda_ * np.eye(A.shape[1])
    b_hat = A.T @ b
    cho = sp.linalg.cho_factor(A_hat)
    return sp.linalg.cho_solve(cho, b_hat)
```

@pstjohn?

@djinnome
Collaborator

Nevermind, that is a Cholesky solve using SciPy. I don't know if PyTensor has Cholesky solvers.

@djinnome
Collaborator

Should we try HMC? It has gotten a lot faster.

@djinnome
Collaborator

Nevermind. You still have the same bottleneck

@mcnaughtonadm
Collaborator Author

mcnaughtonadm commented May 22, 2024

Yeah, so for my initial "profiling" I ran the following snippet:

```python
model.profile(model.logp()).summary()
```

which is agnostic of any inference and only evaluates the log probability of a state in the pymc model itself. So any slowdown is occurring in the model formulation, not in the inference over the model.

But the slowdown does follow the model into the inference step, hence the problem.

@mcnaughtonadm
Collaborator Author

Nevermind, that is a Cholesky solve using SciPy. I don't know if PyTensor has Cholesky solvers.

Looking around the PyTensor docs, they do seem to have other solve methods available, mainly Cholesky and triangular. The original Theano implementation also had access to these, just in a different way. Do you think a Cholesky solve of the system of equations would improve performance over a standard solve?

The CholeskySolve class can be found here.
https://github.com/pymc-devs/pytensor/blob/bb028ae2330433755b9d4aa32ab6e8d0c9f662fc/pytensor/tensor/slinalg.py#L237

The standard Solve class that emll's LeastSquaresSolve builds on is here:
https://github.com/pymc-devs/pytensor/blob/bb028ae2330433755b9d4aa32ab6e8d0c9f662fc/pytensor/tensor/slinalg.py#L366

@djinnome
Collaborator

I would say it is worth a shot. Peter St John felt that the absence of a Cholesky solver in PyTorch and TensorFlow was a reason not to try porting emll to those frameworks.

@mcnaughtonadm
Collaborator Author

I am also trying something else that I noticed. In Theano, the Solve class inherits directly from Op, but in PyTensor there is a SolveBase that inherits directly from Op, and Solve inherits from SolveBase. I am checking whether this extra layer causes unnecessary computation by switching emll's LeastSquaresSolve to inherit from SolveBase instead of Solve; a rough sketch of the idea is below.
