Commit 39d6c88 (parent 398c7c4)
Showing 15 changed files with 3,974 additions and 1 deletion.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"ename": "RuntimeError", | ||
"evalue": "Given groups=1, weight of size [1, 1, 10], expected input[1, 2, 10] to have 1 channels, but got 2 channels instead", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", | ||
"Cell \u001b[0;32mIn[2], line 34\u001b[0m\n\u001b[1;32m 31\u001b[0m decay_filter \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mexp(\u001b[39m-\u001b[39mtorch\u001b[39m.\u001b[39marange(timesteps)\u001b[39m.\u001b[39mfloat())\u001b[39m.\u001b[39mview(\u001b[39m1\u001b[39m, \u001b[39m1\u001b[39m, \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m) \u001b[39m# Shape [1, 1, timesteps]\u001b[39;00m\n\u001b[1;32m 33\u001b[0m \u001b[39m# Perform convolution over all layers at once\u001b[39;00m\n\u001b[0;32m---> 34\u001b[0m input_contribution \u001b[39m=\u001b[39m F\u001b[39m.\u001b[39;49mconv1d(inputs, decay_filter, padding\u001b[39m=\u001b[39;49mtimesteps\u001b[39m-\u001b[39;49m\u001b[39m1\u001b[39;49m)\n\u001b[1;32m 36\u001b[0m \u001b[39m# Combine decayed initial state and input contributions\u001b[39;00m\n\u001b[1;32m 37\u001b[0m h_total \u001b[39m=\u001b[39m h_exp\u001b[39m.\u001b[39munsqueeze(\u001b[39m1\u001b[39m) \u001b[39m+\u001b[39m input_contribution\u001b[39m.\u001b[39mpermute(\u001b[39m0\u001b[39m, \u001b[39m2\u001b[39m, \u001b[39m1\u001b[39m) \u001b[39m# Combine along the time axis\u001b[39;00m\n", | ||
"\u001b[0;31mRuntimeError\u001b[0m: Given groups=1, weight of size [1, 1, 10], expected input[1, 2, 10] to have 1 channels, but got 2 channels instead" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import torch\n", | ||
"import torch.nn.functional as F\n", | ||
"\n", | ||
"# Network parameters\n", | ||
"timesteps = 10\n", | ||
"input_size = 2 # Two layers\n", | ||
"hidden_size = 2 # Two layers\n", | ||
"batch_size = 1\n", | ||
"\n", | ||
"# Initialize activations and inputs for two layers\n", | ||
"h_init = torch.randn(batch_size, hidden_size) # Initial activations h_1(0), h_2(0)\n", | ||
"inputs = torch.randn(batch_size, timesteps, hidden_size) # Inputs I_1(t), I_2(t)\n", | ||
"\n", | ||
"# Define the connection matrix A (self-decay + inter-layer connections)\n", | ||
"# A = [[alpha1, beta1],\n", | ||
"# [beta2, alpha2]]\n", | ||
"A = torch.tensor([[0.9, 0.1], # alpha1 = 0.9, beta1 = 0.1\n", | ||
" [0.2, 0.8]]) # beta2 = 0.2, alpha2 = 0.8\n", | ||
"\n", | ||
"# Step 1: Exponentially decay the initial state using matrix exponential\n", | ||
"# h(t) = exp(A * t) * h(0)\n", | ||
"exp_A = torch.matrix_exp(A) # Matrix exponential of A\n", | ||
"h_exp = torch.matmul(exp_A, h_init.T).T # Decayed initial state\n", | ||
"\n", | ||
"# Step 2: Fully parallelized convolution to handle the input contribution over time\n", | ||
"\n", | ||
"# Reshape inputs to be batch x channels x timesteps for conv1d\n", | ||
"inputs = inputs.permute(0, 2, 1) # Now it's [batch_size, layers (channels), timesteps]\n", | ||
"\n", | ||
"# Compute the decay filter\n", | ||
"decay_filter = torch.exp(-torch.arange(timesteps).float()).view(1, 1, -1) # Shape [1, 1, timesteps]\n", | ||
"\n", | ||
"# Perform convolution over all layers at once\n", | ||
"input_contribution = F.conv1d(inputs, decay_filter, padding=timesteps-1)\n", | ||
"\n", | ||
"# Combine decayed initial state and input contributions\n", | ||
"h_total = h_exp.unsqueeze(1) + input_contribution.permute(0, 2, 1) # Combine along the time axis\n", | ||
"\n", | ||
"# Output activations for each layer across all time steps\n", | ||
"print(\"Activations across time for both layers:\")\n", | ||
"print(h_total)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Ground Truth (Iterative):\n", | ||
"tensor([[ 1.0000, 1.0000],\n", | ||
" [ 2.4000, 2.3000],\n", | ||
" [ 5.3100, 5.0400],\n", | ||
" [ 10.2990, 9.6870],\n", | ||
" [ 18.1126, 16.8991],\n", | ||
" [ 29.7509, 27.5756],\n", | ||
" [ 46.5636, 42.9359],\n", | ||
" [ 70.3752, 64.6305],\n", | ||
" [103.6529, 94.8920],\n", | ||
" [149.7336, 136.7401]])\n", | ||
"\n", | ||
"Parallel Computation:\n", | ||
"tensor([[ 2.4000, 2.3000],\n", | ||
" [ 3.9000, 3.8000],\n", | ||
" [ 5.4000, 5.3000],\n", | ||
" [ 6.9000, 6.8000],\n", | ||
" [ 8.4000, 8.3000],\n", | ||
" [ 9.9000, 9.8000],\n", | ||
" [11.4000, 11.3000],\n", | ||
" [12.9000, 12.8000],\n", | ||
" [14.4000, 14.3000],\n", | ||
" [15.9000, 15.8000]])\n", | ||
"\n", | ||
"Difference (should be small):\n", | ||
"tensor([[-1.4000e+00, -1.3000e+00],\n", | ||
" [-1.5000e+00, -1.5000e+00],\n", | ||
" [-9.0000e-02, -2.6000e-01],\n", | ||
" [ 3.3990e+00, 2.8870e+00],\n", | ||
" [ 9.7126e+00, 8.5991e+00],\n", | ||
" [ 1.9851e+01, 1.7776e+01],\n", | ||
" [ 3.5164e+01, 3.1636e+01],\n", | ||
" [ 5.7475e+01, 5.1831e+01],\n", | ||
" [ 8.9253e+01, 8.0592e+01],\n", | ||
" [ 1.3383e+02, 1.2094e+02]])\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import torch.linalg as linalg\n", | ||
"\n", | ||
"# Initial setup\n", | ||
"timesteps = 10\n", | ||
"input_size = 2 # 2 layers\n", | ||
"hidden_size = 2 # 2 neurons (one in each layer)\n", | ||
"\n", | ||
"# Define the interaction matrix B\n", | ||
"B = torch.tensor([[0.0, 0.5], [0.5, 0.0]]) # Interaction between the two layers\n", | ||
"assert B.shape == (hidden_size, hidden_size), f\"Expected B to have shape {(hidden_size, hidden_size)}, got {B.shape}\"\n", | ||
"\n", | ||
"# Initial states\n", | ||
"h_0 = torch.tensor([1.0, 1.0]) # Initial state for both layers\n", | ||
"assert h_0.shape == (hidden_size,), f\"Expected h_0 to have shape {(hidden_size,)}, got {h_0.shape}\"\n", | ||
"\n", | ||
"# Define decay matrix D (diagonal, no interaction, just decay)\n", | ||
"decay_factors = torch.tensor([0.9, 0.8]) # Decay constants for the two layers\n", | ||
"D = torch.diag(decay_factors)\n", | ||
"assert D.shape == (hidden_size, hidden_size), f\"Expected D to have shape {(hidden_size, hidden_size)}, got {D.shape}\"\n", | ||
"\n", | ||
"# Define input currents over time (just for example purposes, static input here)\n", | ||
"I = torch.ones((timesteps, input_size))\n", | ||
"assert I.shape == (timesteps, input_size), f\"Expected I to have shape {(timesteps, input_size)}, got {I.shape}\"\n", | ||
"\n", | ||
"# Iterative Approach (Ground Truth)\n", | ||
"def iterative_computation(h_0, B, D, I, timesteps):\n", | ||
" h = h_0.clone()\n", | ||
" states = [h_0] # Store the states over time\n", | ||
" for t in range(1, timesteps):\n", | ||
" # Compute new state with interaction, decay, and input\n", | ||
" h = B @ h + D @ h + torch.sum(I[:t], dim=0) # Interaction + decay + cumulative input\n", | ||
" assert h.shape == (hidden_size,), f\"Expected h to have shape {(hidden_size,)}, got {h.shape}\"\n", | ||
" states.append(h)\n", | ||
" result = torch.stack(states) # Return all states over time\n", | ||
" assert result.shape == (timesteps, hidden_size), f\"Expected result to have shape {(timesteps, hidden_size)}, got {result.shape}\"\n", | ||
" return result\n", | ||
"\n", | ||
"# Parallelized Approach Using Matrix Powers (Fully Parallel)\n", | ||
"def parallel_computation_matrix_powers(h_0, B, D, I, timesteps):\n", | ||
" # Stack powers of B to simulate interaction across timesteps\n", | ||
" B_powers = torch.stack([torch.matrix_power(B, t) for t in range(timesteps)], dim=0)\n", | ||
" assert B_powers.shape == (timesteps, hidden_size, hidden_size), f\"Expected B_powers to have shape {(timesteps, hidden_size, hidden_size)}, got {B_powers.shape}\"\n", | ||
" \n", | ||
" # Compute the effect of the interaction across all timesteps in parallel\n", | ||
" interaction_contributions = torch.matmul(B_powers, h_0.unsqueeze(1)).squeeze(2)\n", | ||
" assert interaction_contributions.shape == (timesteps, hidden_size), f\"Expected interaction_contributions to have shape {(timesteps, hidden_size)}, got {interaction_contributions.shape}\"\n", | ||
" \n", | ||
" # Compute decay contributions for all timesteps\n", | ||
" decay_contributions = torch.matmul(h_0.unsqueeze(0).repeat(timesteps, 1), D.T)\n", | ||
" assert decay_contributions.shape == (timesteps, hidden_size), f\"Expected decay_contributions to have shape {(timesteps, hidden_size)}, got {decay_contributions.shape}\"\n", | ||
" \n", | ||
" # Compute cumulative sum of input contributions\n", | ||
" input_contributions = torch.cumsum(I, dim=0)\n", | ||
" assert input_contributions.shape == (timesteps, hidden_size), f\"Expected input_contributions to have shape {(timesteps, hidden_size)}, got {input_contributions.shape}\"\n", | ||
" \n", | ||
" # Combine all contributions\n", | ||
" final_states = decay_contributions + input_contributions + interaction_contributions\n", | ||
" assert final_states.shape == (timesteps, hidden_size), f\"Expected final_states to have shape {(timesteps, hidden_size)}, got {final_states.shape}\"\n", | ||
" \n", | ||
" return final_states\n", | ||
"\n", | ||
"# Run the two approaches\n", | ||
"ground_truth = iterative_computation(h_0, B, D, I, timesteps)\n", | ||
"parallel_result = parallel_computation_matrix_powers(h_0, B, D, I, timesteps)\n", | ||
"\n", | ||
"# Compare the results\n", | ||
"print(\"Ground Truth (Iterative):\")\n", | ||
"print(ground_truth)\n", | ||
"\n", | ||
"print(\"\\nParallel Computation (Matrix Powers):\")\n", | ||
"print(parallel_result)\n", | ||
"\n", | ||
"# Difference between the two\n", | ||
"print(\"\\nDifference (should be small):\")\n", | ||
"print(ground_truth - parallel_result)\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "env", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.19" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
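The first cell aborts inside `F.conv1d` because the decay filter has a single input channel while `inputs` carries two channels (one per layer). Below is a minimal sketch of one way to get the convolution to run, assuming the intent was an independent, causal exponential-decay convolution per layer via a depthwise (grouped) convolution; apart from `timesteps`, the names are illustrative rather than taken from the notebook.

```python
import torch
import torch.nn.functional as F

timesteps, layers, batch_size = 10, 2, 1
inputs = torch.randn(batch_size, layers, timesteps)  # [batch, channels, time], as in the cell after permute

# One decay kernel per layer: weight shape [channels, 1, timesteps] with groups=channels
decay = torch.exp(-torch.arange(timesteps).float())
decay_filter = decay.flip(0).view(1, 1, -1).repeat(layers, 1, 1)
# flip(0) turns conv1d's cross-correlation into a convolution, so keeping the first
# `timesteps` outputs gives the causal sum_{s<=t} e^{-(t-s)} * I(s) per layer.

contribution = F.conv1d(inputs, decay_filter, padding=timesteps - 1, groups=layers)
contribution = contribution[..., :timesteps]
print(contribution.shape)  # torch.Size([1, 2, 10])
```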
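Separately, the same cell's comment states `h(t) = exp(A * t) * h(0)`, but the code evaluates `torch.matrix_exp(A)` only once. If a decayed initial state is wanted at every timestep, one per-step sketch (again with illustrative names, and only an assumption about the intended math) would be:

```python
import torch

timesteps = 10
A = torch.tensor([[0.9, 0.1],
                  [0.2, 0.8]])
h_init = torch.randn(1, 2)  # [batch_size, hidden_size]

# exp(A * t) for t = 0 .. timesteps-1, stacked into [timesteps, hidden, hidden]
exp_At = torch.stack([torch.matrix_exp(A * t) for t in range(timesteps)])

# Apply each exp(A * t) to h(0): result is [timesteps, batch_size, hidden_size]
h_decayed = torch.einsum('tij,bj->tbi', exp_At, h_init)
print(h_decayed.shape)  # torch.Size([10, 1, 2])
```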
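In the second cell, the "parallel" output drifts far from the iterative ground truth because the interaction, decay, and input terms are propagated separately, whereas the iterative loop implements the coupled recurrence h_t = (B + D) h_{t-1} + c_t with c_t = sum_{s<t} I_s. A sketch of the corresponding closed form, h_t = M^t h_0 + sum_{k=1..t} M^{t-k} c_k with M = B + D, which should reproduce the iterative states up to floating-point error; the stacked matrix powers could also be batched, and names not taken from the cell are illustrative:

```python
import torch

timesteps, hidden_size = 10, 2
B = torch.tensor([[0.0, 0.5], [0.5, 0.0]])
D = torch.diag(torch.tensor([0.9, 0.8]))
h_0 = torch.tensor([1.0, 1.0])
I = torch.ones((timesteps, hidden_size))

M = B + D  # the matrix actually applied each step in the iterative loop

# M^t for t = 0 .. timesteps-1
M_powers = torch.stack([torch.matrix_power(M, t) for t in range(timesteps)])

# c_t = sum of inputs seen before step t (c_0 = 0, matching states[0] = h_0)
c = torch.cat([torch.zeros(1, hidden_size), torch.cumsum(I, dim=0)[:-1]])

states = []
for t in range(timesteps):
    # homogeneous part M^t h_0 plus the driven part sum_{k=1..t} M^{t-k} c_k
    driven = sum((M_powers[t - k] @ c[k] for k in range(1, t + 1)), torch.zeros(hidden_size))
    states.append(M_powers[t] @ h_0 + driven)

closed_form = torch.stack(states)
print(closed_form)  # should match the iterative ground truth printed above
```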