Skip to content

Commit

Permalink
Add Numpy / Torch / Triton differentiation (#18)
Browse files Browse the repository at this point in the history
* Add Flags for installing different deps

* Handle the "from ..." case for imports

---------

Co-authored-by: Alex Zhang <[email protected]>
Co-authored-by: Mark Saroufim <[email protected]>
  • Loading branch information
3 people authored Nov 12, 2024
1 parent fd10cef commit e5e549d
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
20 changes: 18 additions & 2 deletions .github/workflows/train_workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,24 @@ jobs:

- name: Install dependencies
  run: |
    # Install only the libraries that train.py imports, instead of
    # installing everything unconditionally. Each pattern matches both
    # the "import X" and the "from X ..." form.

    # Check if train.py imports numpy
    if grep -rE "(import numpy|from numpy)" train.py; then
      echo "Numpy detected, installing numpy"
      pip install numpy
    fi
    # Check if train.py imports torch
    if grep -rE "(import torch|from torch)" train.py; then
      echo "PyTorch detected, installing torch"
      pip install torch
    fi
    # Check if train.py imports triton
    if grep -rE "(import triton|from triton)" train.py; then
      echo "Triton detected, installing triton"
      pip install triton
    fi
- name: Create training script
shell: python
Expand Down
28 changes: 26 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,32 @@
import triton.language as tl
import triton
import torch


@triton.jit
def vector_add_kernel(A, B, C, N, BLOCK_SIZE: tl.constexpr):
    # Element-wise addition: C[i] = A[i] + B[i] for every index i < N.
    # A, B and C are used as base addresses that element offsets are added
    # to (presumably pointers to the tensors' data -- confirm at the launch
    # site). BLOCK_SIZE is a compile-time constant giving how many
    # consecutive elements one program instance handles.

    # First element index owned by this program instance along axis 0.
    block_start = tl.program_id(0) * BLOCK_SIZE

    # Indices covered by this block, and which of them fall inside the
    # valid range; the last block may extend past N.
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < N

    # Masked loads of both operands; out-of-range lanes are not read.
    lhs = tl.load(A + idx, mask=in_bounds)
    rhs = tl.load(B + idx, mask=in_bounds)

    # Masked store of the sum; out-of-range lanes are not written.
    tl.store(C + idx, lhs + rhs, mask=in_bounds)


# Build two small tensors and move them to the GPU (requires a CUDA device).
a = torch.Tensor([1, 2, 3, 4, 5]).cuda()
b = torch.Tensor([1, 2, 3, 4, 5]).cuda()

# Print the inputs and their element-wise sum. The sum is computed with
# torch's own addition; vector_add_kernel is not launched here.
print(a)
print(b)
print(a + b)

0 comments on commit e5e549d

Please sign in to comment.