Skip to content

Commit

Permalink
Revert "Add Numpy / Torch / Triton differentiation" (#20)
Browse files Browse the repository at this point in the history
This reverts commit e5e549d.
  • Loading branch information
msaroufim authored Nov 12, 2024
1 parent e0e5ddb commit d9eb39f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 44 deletions.
20 changes: 2 additions & 18 deletions .github/workflows/train_workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,8 @@ jobs:

- name: Install dependencies
run: |
# Check if 'import numpy' is in any Python file
if grep -rE "(import numpy|from numpy)" train.py; then
echo "Numpy detected, installing numpy"
pip install numpy
fi
# Check if 'import torch' is in any Python file
if grep -rE "(import torch|from torch)" train.py; then
echo "PyTorch detected, installing torch"
pip install torch
fi
# Check if 'import triton' is in any Python file
if grep -rE "(import triton|from triton)" train.py; then
echo "Triton detected, installing triton"
pip install triton
fi
pip install numpy
pip install torch
- name: Create training script
shell: python
Expand Down
28 changes: 2 additions & 26 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,8 @@
import triton.language as tl
import triton
import torch


@triton.jit
def vector_add_kernel(A, B, C, N, BLOCK_SIZE: tl.constexpr):
# Get the unique program ID for each block
pid = tl.program_id(0)

# Calculate the start index for each block
start = pid * BLOCK_SIZE

# Load data from A and B into registers for vector addition
offset = start + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + offset, mask=offset < N) # Load elements from A
b = tl.load(B + offset, mask=offset < N) # Load elements from B

# Perform element-wise addition
c = a + b

# Store the result back into C
tl.store(C + offset, c, mask=offset < N)


a = torch.Tensor([1, 2, 3, 4, 5]).cuda()
b = torch.Tensor([1, 2, 3, 4, 5]).cuda()
b= torch.Tensor([1, 2, 3, 4, 5]).cuda()

print(a)
print(b)
print(a + b)

print(a + b)

0 comments on commit d9eb39f

Please sign in to comment.