squash commits
and-rewsmith committed Oct 30, 2024
1 parent 398c7c4 commit 39d6c88
Showing 15 changed files with 3,974 additions and 1 deletion.
Binary file added baseline.prof
Binary file not shown.
198 changes: 198 additions & 0 deletions playground/bench.ipynb

Large diffs are not rendered by default.

214 changes: 214 additions & 0 deletions playground/math_multilayer_bptt.ipynb
@@ -0,0 +1,214 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "Given groups=1, weight of size [1, 1, 10], expected input[1, 2, 10] to have 1 channels, but got 2 channels instead",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 34\u001b[0m\n\u001b[1;32m 31\u001b[0m decay_filter \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mexp(\u001b[39m-\u001b[39mtorch\u001b[39m.\u001b[39marange(timesteps)\u001b[39m.\u001b[39mfloat())\u001b[39m.\u001b[39mview(\u001b[39m1\u001b[39m, \u001b[39m1\u001b[39m, \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m) \u001b[39m# Shape [1, 1, timesteps]\u001b[39;00m\n\u001b[1;32m 33\u001b[0m \u001b[39m# Perform convolution over all layers at once\u001b[39;00m\n\u001b[0;32m---> 34\u001b[0m input_contribution \u001b[39m=\u001b[39m F\u001b[39m.\u001b[39;49mconv1d(inputs, decay_filter, padding\u001b[39m=\u001b[39;49mtimesteps\u001b[39m-\u001b[39;49m\u001b[39m1\u001b[39;49m)\n\u001b[1;32m 36\u001b[0m \u001b[39m# Combine decayed initial state and input contributions\u001b[39;00m\n\u001b[1;32m 37\u001b[0m h_total \u001b[39m=\u001b[39m h_exp\u001b[39m.\u001b[39munsqueeze(\u001b[39m1\u001b[39m) \u001b[39m+\u001b[39m input_contribution\u001b[39m.\u001b[39mpermute(\u001b[39m0\u001b[39m, \u001b[39m2\u001b[39m, \u001b[39m1\u001b[39m) \u001b[39m# Combine along the time axis\u001b[39;00m\n",
"\u001b[0;31mRuntimeError\u001b[0m: Given groups=1, weight of size [1, 1, 10], expected input[1, 2, 10] to have 1 channels, but got 2 channels instead"
]
}
],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# Network parameters\n",
"timesteps = 10\n",
"input_size = 2 # Two layers\n",
"hidden_size = 2 # Two layers\n",
"batch_size = 1\n",
"\n",
"# Initialize activations and inputs for two layers\n",
"h_init = torch.randn(batch_size, hidden_size) # Initial activations h_1(0), h_2(0)\n",
"inputs = torch.randn(batch_size, timesteps, hidden_size) # Inputs I_1(t), I_2(t)\n",
"\n",
"# Define the connection matrix A (self-decay + inter-layer connections)\n",
"# A = [[alpha1, beta1],\n",
"# [beta2, alpha2]]\n",
"A = torch.tensor([[0.9, 0.1], # alpha1 = 0.9, beta1 = 0.1\n",
" [0.2, 0.8]]) # beta2 = 0.2, alpha2 = 0.8\n",
"\n",
"# Step 1: Exponentially decay the initial state using matrix exponential\n",
"# h(t) = exp(A * t) * h(0)\n",
"exp_A = torch.matrix_exp(A) # Matrix exponential of A\n",
"h_exp = torch.matmul(exp_A, h_init.T).T # Decayed initial state\n",
"\n",
"# Step 2: Fully parallelized convolution to handle the input contribution over time\n",
"\n",
"# Reshape inputs to be batch x channels x timesteps for conv1d\n",
"inputs = inputs.permute(0, 2, 1) # Now it's [batch_size, layers (channels), timesteps]\n",
"\n",
"# Compute the decay filter\n",
"decay_filter = torch.exp(-torch.arange(timesteps).float()).view(1, 1, -1) # Shape [1, 1, timesteps]\n",
"\n",
"# Perform convolution over all layers at once\n",
"input_contribution = F.conv1d(inputs, decay_filter, padding=timesteps-1)\n",
"\n",
"# Combine decayed initial state and input contributions\n",
"h_total = h_exp.unsqueeze(1) + input_contribution.permute(0, 2, 1) # Combine along the time axis\n",
"\n",
"# Output activations for each layer across all time steps\n",
"print(\"Activations across time for both layers:\")\n",
"print(h_total)\n"
]
},
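{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to read the cell above: each layer is a `conv1d` channel, and the grouped (depthwise) convolution applies the kernel $e^{-(t-s)}$ to every layer independently. Padding by `timesteps - 1` and slicing to the first `timesteps` outputs keeps the filter causal, so each layer receives $h^{\\mathrm{in}}_\\ell(t) = \\sum_{s \\le t} e^{-(t-s)} I_\\ell(s)$.\n"
]
},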
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ground Truth (Iterative):\n",
"tensor([[ 1.0000, 1.0000],\n",
" [ 2.4000, 2.3000],\n",
" [ 5.3100, 5.0400],\n",
" [ 10.2990, 9.6870],\n",
" [ 18.1126, 16.8991],\n",
" [ 29.7509, 27.5756],\n",
" [ 46.5636, 42.9359],\n",
" [ 70.3752, 64.6305],\n",
" [103.6529, 94.8920],\n",
" [149.7336, 136.7401]])\n",
"\n",
"Parallel Computation:\n",
"tensor([[ 2.4000, 2.3000],\n",
" [ 3.9000, 3.8000],\n",
" [ 5.4000, 5.3000],\n",
" [ 6.9000, 6.8000],\n",
" [ 8.4000, 8.3000],\n",
" [ 9.9000, 9.8000],\n",
" [11.4000, 11.3000],\n",
" [12.9000, 12.8000],\n",
" [14.4000, 14.3000],\n",
" [15.9000, 15.8000]])\n",
"\n",
"Difference (should be small):\n",
"tensor([[-1.4000e+00, -1.3000e+00],\n",
" [-1.5000e+00, -1.5000e+00],\n",
" [-9.0000e-02, -2.6000e-01],\n",
" [ 3.3990e+00, 2.8870e+00],\n",
" [ 9.7126e+00, 8.5991e+00],\n",
" [ 1.9851e+01, 1.7776e+01],\n",
" [ 3.5164e+01, 3.1636e+01],\n",
" [ 5.7475e+01, 5.1831e+01],\n",
" [ 8.9253e+01, 8.0592e+01],\n",
" [ 1.3383e+02, 1.2094e+02]])\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.linalg as linalg\n",
"\n",
"# Initial setup\n",
"timesteps = 10\n",
"input_size = 2 # 2 layers\n",
"hidden_size = 2 # 2 neurons (one in each layer)\n",
"\n",
"# Define the interaction matrix B\n",
"B = torch.tensor([[0.0, 0.5], [0.5, 0.0]]) # Interaction between the two layers\n",
"assert B.shape == (hidden_size, hidden_size), f\"Expected B to have shape {(hidden_size, hidden_size)}, got {B.shape}\"\n",
"\n",
"# Initial states\n",
"h_0 = torch.tensor([1.0, 1.0]) # Initial state for both layers\n",
"assert h_0.shape == (hidden_size,), f\"Expected h_0 to have shape {(hidden_size,)}, got {h_0.shape}\"\n",
"\n",
"# Define decay matrix D (diagonal, no interaction, just decay)\n",
"decay_factors = torch.tensor([0.9, 0.8]) # Decay constants for the two layers\n",
"D = torch.diag(decay_factors)\n",
"assert D.shape == (hidden_size, hidden_size), f\"Expected D to have shape {(hidden_size, hidden_size)}, got {D.shape}\"\n",
"\n",
"# Define input currents over time (just for example purposes, static input here)\n",
"I = torch.ones((timesteps, input_size))\n",
"assert I.shape == (timesteps, input_size), f\"Expected I to have shape {(timesteps, input_size)}, got {I.shape}\"\n",
"\n",
"# Iterative Approach (Ground Truth)\n",
"def iterative_computation(h_0, B, D, I, timesteps):\n",
" h = h_0.clone()\n",
" states = [h_0] # Store the states over time\n",
" for t in range(1, timesteps):\n",
" # Compute new state with interaction, decay, and input\n",
" h = B @ h + D @ h + torch.sum(I[:t], dim=0) # Interaction + decay + cumulative input\n",
" assert h.shape == (hidden_size,), f\"Expected h to have shape {(hidden_size,)}, got {h.shape}\"\n",
" states.append(h)\n",
" result = torch.stack(states) # Return all states over time\n",
" assert result.shape == (timesteps, hidden_size), f\"Expected result to have shape {(timesteps, hidden_size)}, got {result.shape}\"\n",
" return result\n",
"\n",
"# Parallelized Approach Using Matrix Powers (Fully Parallel)\n",
"def parallel_computation_matrix_powers(h_0, B, D, I, timesteps):\n",
" # Stack powers of B to simulate interaction across timesteps\n",
" B_powers = torch.stack([torch.matrix_power(B, t) for t in range(timesteps)], dim=0)\n",
" assert B_powers.shape == (timesteps, hidden_size, hidden_size), f\"Expected B_powers to have shape {(timesteps, hidden_size, hidden_size)}, got {B_powers.shape}\"\n",
" \n",
" # Compute the effect of the interaction across all timesteps in parallel\n",
" interaction_contributions = torch.matmul(B_powers, h_0.unsqueeze(1)).squeeze(2)\n",
" assert interaction_contributions.shape == (timesteps, hidden_size), f\"Expected interaction_contributions to have shape {(timesteps, hidden_size)}, got {interaction_contributions.shape}\"\n",
" \n",
" # Compute decay contributions for all timesteps\n",
" decay_contributions = torch.matmul(h_0.unsqueeze(0).repeat(timesteps, 1), D.T)\n",
" assert decay_contributions.shape == (timesteps, hidden_size), f\"Expected decay_contributions to have shape {(timesteps, hidden_size)}, got {decay_contributions.shape}\"\n",
" \n",
" # Compute cumulative sum of input contributions\n",
" input_contributions = torch.cumsum(I, dim=0)\n",
" assert input_contributions.shape == (timesteps, hidden_size), f\"Expected input_contributions to have shape {(timesteps, hidden_size)}, got {input_contributions.shape}\"\n",
" \n",
" # Combine all contributions\n",
" final_states = decay_contributions + input_contributions + interaction_contributions\n",
" assert final_states.shape == (timesteps, hidden_size), f\"Expected final_states to have shape {(timesteps, hidden_size)}, got {final_states.shape}\"\n",
" \n",
" return final_states\n",
"\n",
"# Run the two approaches\n",
"ground_truth = iterative_computation(h_0, B, D, I, timesteps)\n",
"parallel_result = parallel_computation_matrix_powers(h_0, B, D, I, timesteps)\n",
"\n",
"# Compare the results\n",
"print(\"Ground Truth (Iterative):\")\n",
"print(ground_truth)\n",
"\n",
"print(\"\\nParallel Computation (Matrix Powers):\")\n",
"print(parallel_result)\n",
"\n",
"# Difference between the two\n",
"print(\"\\nDifference (should be small):\")\n",
"print(ground_truth - parallel_result)\n"
]
}
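,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why the matrix-power version matches the loop (a short derivation from the recurrence as coded): write $M = B + D$ and $C_t = \\sum_{s<t} I_s$ for the cumulative input, so each step is $h_t = M h_{t-1} + C_t$. Unrolling gives\n",
"\n",
"$$h_t = M^t h_0 + \\sum_{j=1}^{t} M^{t-j} C_j,$$\n",
"\n",
"which is exactly what the batched matrix powers and the masked `einsum` evaluate for all $t$ at once.\n"
]
}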
],
"metadata": {
"kernelspec": {
"display_name": "env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}