Skip to content

Commit

Permalink
Updated tests and blacked
Browse files Browse the repository at this point in the history
  • Loading branch information
reverendbedford committed Jan 12, 2024
1 parent ed0c0a1 commit ea9f582
Show file tree
Hide file tree
Showing 29 changed files with 1,130 additions and 780 deletions.
255 changes: 165 additions & 90 deletions examples/ode/damping.py

Large diffs are not rendered by default.

324 changes: 186 additions & 138 deletions examples/ode/neuron.py

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions examples/ode/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
# Prior for the noise
eps_prior = 0.1 # Just measure variance in data...


def model_act(times):
"""
times: ntime x nbatch
Expand All @@ -77,7 +78,7 @@ def model_act(times):
torch.stack(
(
v * torch.cos(a) * times,
v * torch.sin(a) * times - 0.5 * g * times ** 2.0,
v * torch.sin(a) * times - 0.5 * g * times**2.0,
)
).T,
eps_act,
Expand All @@ -88,7 +89,7 @@ def model_act(times):


class Integrator(pyro.nn.PyroModule):
def __init__(self, eqn, y0, extra_params=[], block_size = 1):
def __init__(self, eqn, y0, extra_params=[], block_size=1):
super().__init__()
self.eqn = eqn
self.y0 = y0
Expand All @@ -100,7 +101,7 @@ def forward(self, times):
self.eqn,
self.y0,
times,
block_size = self.block_size,
block_size=self.block_size,
extra_params=self.extra_params,
)

Expand Down Expand Up @@ -274,8 +275,9 @@ def gen_extra(self):
pyro.clear_param_store()

def maker(v, a, **kwargs):
return Integrator(ODE(v, a), torch.zeros(nsamples, 2),
block_size = time_block, **kwargs)
return Integrator(
ODE(v, a), torch.zeros(nsamples, 2), block_size=time_block, **kwargs
)

# Setup the model
model = Model(
Expand Down
44 changes: 34 additions & 10 deletions examples/ode/vanderpohl.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,21 @@ def forward(self, t, y):

model = VanderPolODE(mu)

res_exp = ode.odeint(model, y0, times, method="forward-euler",
block_size = time_chunk, guess_type = "previous")
res_exp = ode.odeint(
model,
y0,
times,
method="forward-euler",
block_size=time_chunk,
guess_type="previous",
)
res_imp = ode.odeint(
model, y0, times, method="backward-euler",
block_size = time_chunk, guess_type = "previous"
model,
y0,
times,
method="backward-euler",
block_size=time_chunk,
guess_type="previous",
)

plt.figure()
Expand All @@ -103,11 +113,21 @@ def forward(self, t, y):

model = VanderPolODE(mu)

res_exp = ode.odeint(model, y0, times, method="forward-euler",
block_size = time_chunk, guess_type = "previous")
res_exp = ode.odeint(
model,
y0,
times,
method="forward-euler",
block_size=time_chunk,
guess_type="previous",
)
res_imp = ode.odeint(
model, y0, times, method="backward-euler",
block_size = time_chunk, guess_type = "previous"
model,
y0,
times,
method="backward-euler",
block_size=time_chunk,
guess_type="previous",
)

plt.figure()
Expand All @@ -126,8 +146,12 @@ def forward(self, t, y):
model = VanderPolODE(mu)

res_imp = ode.odeint(
model, y0, times, method="backward-euler",
block_size = time_chunk, guess_type = "previous"
model,
y0,
times,
method="backward-euler",
block_size=time_chunk,
guess_type="previous",
)

plt.figure()
Expand Down
63 changes: 39 additions & 24 deletions examples/performance/linear_solve_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
dev = "cpu"
device = torch.device(dev)

def run_time(operator, D, L, v, repeat = 1):

def run_time(operator, D, L, v, repeat=1):
times = []
for i in range(repeat):
t1 = time.time()
Expand All @@ -31,6 +32,7 @@ def run_time(operator, D, L, v, repeat = 1):

return np.mean(times)


if __name__ == "__main__":
# Number of repeated trials to average over
avg = 3
Expand All @@ -42,41 +44,54 @@ def run_time(operator, D, L, v, repeat = 1):
max_blk = 1000
# Number of samples in range(1,max_blk) to sample
num_samples = 10

nblks = sorted(random.sample(list(range(1,max_blk)), num_samples))

methods = [chunktime.BidiagonalThomasFactorization, chunktime.BidiagonalPCRFactorization,
lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size = 8),
lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size = 16),
lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size = 32),
lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size = 64),
lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size = 128)]

method_names = ["Thomas", "PCR", "Hybrid, n = 8", "Hybrid, n = 16", "Hybrid, n = 32",
"Hybrid, n = 64", "Hybrid, n = 128"]
nblks = sorted(random.sample(list(range(1, max_blk)), num_samples))

methods = [
chunktime.BidiagonalThomasFactorization,
chunktime.BidiagonalPCRFactorization,
lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size=8),
lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size=16),
lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size=32),
lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size=64),
lambda A, B: chunktime.BidiagonalHybridFactorization(A, B, min_size=128),
]

method_names = [
"Thomas",
"PCR",
"Hybrid, n = 8",
"Hybrid, n = 16",
"Hybrid, n = 32",
"Hybrid, n = 64",
"Hybrid, n = 128",
]

nmethods = len(methods)
ncase = len(nblks)

times = np.zeros((nmethods, ncase))

# Do this once to warm up the GPU, it seems to matter
run_time(methods[0], torch.rand(3, nbat, nsize, nsize, device = device),
torch.rand(2, nbat, nsize, nsize, device = device) / 10.0,
torch.rand(nbat, 3 * nsize, device = device))

for i,nblk in enumerate(nblks):
# Do this once to warm up the GPU, it seems to matter
run_time(
methods[0],
torch.rand(3, nbat, nsize, nsize, device=device),
torch.rand(2, nbat, nsize, nsize, device=device) / 10.0,
torch.rand(nbat, 3 * nsize, device=device),
)

for i, nblk in enumerate(nblks):
print(nblk)
D = torch.rand(nblk, nbat, nsize, nsize, device = device)
L = torch.rand(nblk - 1, nbat, nsize, nsize, device = device) / 10.0
D = torch.rand(nblk, nbat, nsize, nsize, device=device)
L = torch.rand(nblk - 1, nbat, nsize, nsize, device=device) / 10.0

v = torch.rand(nbat, nblk * nsize, device = device)
v = torch.rand(nbat, nblk * nsize, device=device)
for j, method in enumerate(methods):
times[j,i] = run_time(method, D, L, v, repeat = avg)
times[j, i] = run_time(method, D, L, v, repeat=avg)

data = pd.DataFrame(data = times.T, index = nblks, columns = method_names)
data = pd.DataFrame(data=times.T, index=nblks, columns=method_names)
data.avg = avg
data.nsize = nsize
data.nbat = nbat

data.to_csv(f"{nbat}_{nsize}.csv")
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ def make(n, eta, s0, R, d, C, g, **kwargs):
print("")

# 3) Create the actual model
model = optimize.DeterministicModel(lambda *args, **kwargs: make(*args, block_size = time_chunk_size, **kwargs), names, ics)
model = optimize.DeterministicModel(
lambda *args, **kwargs: make(*args, block_size=time_chunk_size, **kwargs),
names,
ics,
)

# 4) Setup the optimizer
niter = 200
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ def make(n, eta, s0, R, d, C, g, **kwargs):

# 3) Create the actual model
model = optimize.HierarchicalStatisticalModel(
lambda *args, **kwargs: make(*args, block_size = time_chunk_size, **kwargs),
lambda *args, **kwargs: make(*args, block_size=time_chunk_size, **kwargs),
names,
loc_loc_priors,
loc_scale_priors,
scale_scale_priors,
eps,
include_noise=False,
use_cached_guess= True
use_cached_guess=True,
).to(device)

# 4) Get the guide
Expand Down
Loading

0 comments on commit ea9f582

Please sign in to comment.