Adds parallel cyclic reduction (PCR) as an option for solving batched time systems (#25)

Parallel cyclic reduction in theory has quite good scaling/vectorization on the GPU -- better than the Thomas algorithm, at least. In practice it seems to perform well for problems with a small number of coupled ODEs. This PR adds the algorithm and updates some of the code infrastructure so the caller can switch between the different linear solver approaches.
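
For reference, parallel cyclic reduction decouples the system in about log2(n) passes: at every pass each row eliminates its coupling to the row a fixed stride below it, all rows at once, and the stride doubles until only the diagonal is left. The sketch below shows the idea on a scalar lower-bidiagonal system b[i]*x[i] + a[i]*x[i-1] = d[i]; it is an illustration only, not the blocked implementation added to pyoptmat.chunktime, and the function name is made up.

import torch

def pcr_bidiagonal_sketch(b, a, d):
    # b: (n,) diagonal, a: (n,) sub-diagonal (a[0] unused), d: (n,) right-hand side
    n = b.shape[0]
    a = a.clone()
    d = d.clone()
    s = 1
    while s < n:
        # Fold row i-s into row i for every row in parallel; afterwards the
        # remaining coupling of row i is to row i-2*s
        i = torch.arange(s, n)
        f = a[i] / b[i - s]
        d[i] = d[i] - f * d[i - s]
        new_a = torch.zeros_like(a)
        if 2 * s < n:
            j = torch.arange(2 * s, n)
            new_a[j] = -(a[j] / b[j - s]) * a[j - s]
        a = new_a
        s *= 2
    # Each row is now just b[i]*x[i] = d[i]
    return d / b

# Quick sanity check against a dense solve
n = 8
b = torch.rand(n, dtype=torch.float64) + 1.0
a = torch.rand(n, dtype=torch.float64)
d = torch.rand(n, dtype=torch.float64)
A = torch.diag(b) + torch.diag(a[1:], -1)
print(torch.allclose(pcr_bidiagonal_sketch(b, a, d), torch.linalg.solve(A, d)))

A blocked version works the same way in principle, with the scalar divisions replaced by small dense solves against the diagonal blocks, which is what lets the method vectorize over both the batch and block dimensions on a GPU.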
reverendbedford authored May 4, 2023
1 parent 6b17430 commit 14cf3a9
Showing 7 changed files with 514 additions and 128 deletions.
14 changes: 10 additions & 4 deletions examples/ode/damping.py
@@ -172,11 +172,12 @@ def random_walk(time, mean, scale, mag):
 # Basic parameters
 n_chain = 5 # Number of spring-dashpot-mass elements
 n_samples = 5 # Number of random force samples
-n_time = 500 # Number of time steps
+n_time = 512+1 # Number of time steps
 integration_method = 'backward-euler'
+direct_solver = "thomas" # Batched, block, bidiagonal direct solver method

 # Time chunking -- best value may vary on your system
-n_chunk = 100
+n_chunk = 2**7

 # Ending time
 t_end = 1.0
@@ -215,7 +216,9 @@ def random_walk(time, mean, scale, mag):

 # Generate the data
 with torch.no_grad():
-    y_data = ode.odeint(model, y0, time, method = integration_method, block_size = n_chunk)
+    y_data = ode.odeint(model, y0, time, method = integration_method,
+                        block_size = n_chunk,
+                        direct_solve_method = direct_solver)

 # The observations will just be the first entry
 observable = y_data[...,0]
@@ -239,7 +242,10 @@ def random_walk(time, mean, scale, mag):
 optim = torch.optim.Adam(ode_model.parameters(), lr)
 def closure():
     optim.zero_grad()
-    pred = ode.odeint_adjoint(ode_model, y0, time, method = integration_method, block_size = n_chunk)
+    pred = ode.odeint_adjoint(ode_model, y0, time,
+                              method = integration_method,
+                              block_size = n_chunk,
+                              direct_solve_method = direct_solver)
     obs = pred[...,0]
     lossv = loss(obs, observable)
     lossv.backward()
82 changes: 82 additions & 0 deletions examples/performance/linear_solve_profile.py
@@ -0,0 +1,82 @@
#!/usr/bin/env python3

import time
import random

import matplotlib.pyplot as plt

import numpy as np
import torch
import pandas as pd

from pyoptmat import chunktime

# Use doubles
torch.set_default_tensor_type(torch.DoubleTensor)

# Select device to run on
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
device = torch.device(dev)

def run_time(operator, D, L, v, repeat = 1):
    times = []
    for i in range(repeat):
        t1 = time.time()
        op = operator(D.clone(), L.clone())
        x = op(v.clone())
        times.append(time.time() - t1)

    return np.mean(times)

if __name__ == "__main__":
    # Number of repeated trials to average over
    avg = 3
    # Size of the blocks in the matrix
    nsize = 10
    # Batch size: number of matrices to solve at once
    nbat = 10
    # Maximum number of blocks in the matrix
    max_blk = 1000
    # Number of samples in range(1, max_blk) to sample
    num_samples = 10

    nblks = sorted(random.sample(list(range(1, max_blk)), num_samples))

    methods = [chunktime.BidiagonalThomasFactorization, chunktime.BidiagonalPCRFactorization,
               lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size = 8),
               lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size = 16),
               lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size = 32),
               lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size = 64),
               lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size = 128)]

    method_names = ["Thomas", "PCR", "Hybrid, n = 8", "Hybrid, n = 16", "Hybrid, n = 32",
                    "Hybrid, n = 64", "Hybrid, n = 128"]

    nmethods = len(methods)
    ncase = len(nblks)

    times = np.zeros((nmethods, ncase))

    # Do this once to warm up the GPU, it seems to matter
    run_time(methods[0], torch.rand(3, nbat, nsize, nsize, device = device),
             torch.rand(2, nbat, nsize, nsize, device = device) / 10.0,
             torch.rand(nbat, 3 * nsize, device = device))

    for i, nblk in enumerate(nblks):
        print(nblk)
        D = torch.rand(nblk, nbat, nsize, nsize, device = device)
        L = torch.rand(nblk - 1, nbat, nsize, nsize, device = device) / 10.0

        v = torch.rand(nbat, nblk * nsize, device = device)
        for j, method in enumerate(methods):
            times[j, i] = run_time(method, D, L, v, repeat = avg)

    data = pd.DataFrame(data = times.T, index = nblks, columns = method_names)
    data.avg = avg
    data.nsize = nsize
    data.nbat = nbat

    data.to_csv(f"{nbat}_{nsize}.csv")
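
The script writes the averaged timings to a CSV file indexed by the number of blocks. A minimal way to compare the solvers afterwards, assuming the default nbat = nsize = 10 so the file is named 10_10.csv (this post-processing step is not part of the PR), could be:

import pandas as pd
import matplotlib.pyplot as plt

# Load the timing table written by linear_solve_profile.py and plot each
# solver's average wall time against the number of blocks
data = pd.read_csv("10_10.csv", index_col = 0)
for name in data.columns:
    plt.loglog(data.index, data[name], label = name)
plt.xlabel("Number of blocks")
plt.ylabel("Average wall time (s)")
plt.legend()
plt.show()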
5 changes: 4 additions & 1 deletion examples/structural-inference/tension/statistical/infer.py
@@ -54,6 +54,8 @@ def make(n, eta, s0, R, d, **kwargs):
 if __name__ == "__main__":
     # Number of vectorized time steps
     time_chunk_size = 40
+    # Method to use to solve linearized implicit systems
+    linear_solve_method = "pcr"

     # 1) Load the data for the variance of interest,
     # cut down to some number of samples, and flatten
@@ -87,7 +89,8 @@ def make(n, eta, s0, R, d, **kwargs):

     # 3) Create the actual model
     model = optimize.HierarchicalStatisticalModel(
-        lambda *args, **kwargs: make(*args, block_size = time_chunk_size, **kwargs),
+        lambda *args, **kwargs: make(*args, block_size = time_chunk_size, direct_solve_method = linear_solve_method,
+            **kwargs),
         names, loc_loc_priors, loc_scale_priors, scale_scale_priors, eps
     ).to(device)

(Diffs for the remaining changed files were not loaded.)
