Skip to content
This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Permalink
Implement MaskedSoftmax
Browse files Browse the repository at this point in the history
  • Loading branch information
danny jang committed Sep 27, 2023
1 parent 21ab240 commit 6eb01b9
Show file tree
Hide file tree
Showing 10 changed files with 512 additions and 3 deletions.
66 changes: 66 additions & 0 deletions benchmarks/benchmark_masked_softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2023 ⓒ Kakao Brain Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import triton
import util

import trident


def masked_softmax(input: torch.Tensor, mask: torch.Tensor, dim: int):
input = torch.where(mask.bool(), float("-inf"), input)
output = torch.nn.functional.softmax(input, dim)

return output


def build_mask(y_size: int, x_size: int, device=None, dtype=None):
mask = torch.randint(0, 2, (y_size, x_size), device=device, dtype=dtype)
mask[0, :] = mask[:, 0] = 0

return mask


@util.report("masked softmax forward", ["x_size"], [128 * i for i in range(1, 21)], {"y_size": 16})
def bench_masked_softmax_forward(y_size, x_size, dtype, backend):
input = torch.randn(y_size, x_size, device="cuda", dtype=dtype)
mask = build_mask(y_size, x_size, "cuda", dtype)

if backend == "torch":
return triton.testing.do_bench_cudagraph(lambda: masked_softmax(input, mask, 1))
else:
return triton.testing.do_bench_cudagraph(lambda: trident.function.masked_softmax(input, mask, 1))


@util.report("masked softmax backward", ["x_size"], [128 * i for i in range(1, 21)], {"y_size": 16})
def bench_masked_softmax_backward(y_size, x_size, dtype, backend):
factory_kwargs = {"device": "cuda", "dtype": dtype}
input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
mask = build_mask(y_size, x_size, "cuda", dtype)
grad_output = torch.randn(y_size, x_size, **factory_kwargs)

if backend == "torch":
output = masked_softmax(input, mask, 1)
else:
output = trident.function.masked_softmax(input, mask, 1)

return triton.testing.do_bench_cudagraph(lambda: output.backward(grad_output, retain_graph=True))


def run_benchmark(mode, show_plots, dtype):
if mode == "forward":
bench_masked_softmax_forward.run(print_data=True, show_plots=show_plots, dtype=dtype)
else:
bench_masked_softmax_backward.run(print_data=True, show_plots=show_plots, dtype=dtype)
5 changes: 5 additions & 0 deletions benchmarks/benchmarker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import benchmark_layer_norm
import benchmark_leaky_relu
import benchmark_linear
import benchmark_masked_softmax
import benchmark_max
import benchmark_mean
import benchmark_prelu
Expand Down Expand Up @@ -58,6 +59,7 @@ def print_scenarios():
"layer-norm",
"leaky-relu",
"linear",
"masked-softmax",
"max",
"mean",
"prelu",
Expand Down Expand Up @@ -99,6 +101,8 @@ def run_benchmarks(scenario, mode, show_plots, dtype):
benchmark_leaky_relu.run_benchmark(mode, show_plots, dtype)
elif scenario == "linear":
benchmark_linear.run_benchmark(mode, show_plots, dtype)
elif scenario == "masked-softmax":
benchmark_masked_softmax.run_benchmark(mode, show_plots, dtype)
elif scenario == "max":
benchmark_max.run_benchmark(mode, show_plots, dtype)
elif scenario == "mean":
Expand Down Expand Up @@ -134,6 +138,7 @@ def run_benchmarks(scenario, mode, show_plots, dtype):
benchmark_layer_norm.run_benchmark(mode, show_plots, dtype)
benchmark_leaky_relu.run_benchmark(mode, show_plots, dtype)
benchmark_linear.run_benchmark(mode, show_plots, dtype)
benchmark_masked_softmax.run_benchmark(mode, show_plots, dtype)
benchmark_max.run_benchmark(mode, show_plots, dtype)
benchmark_mean.run_benchmark(mode, show_plots, dtype)
benchmark_prelu.run_benchmark(mode, show_plots, dtype)
Expand Down
77 changes: 77 additions & 0 deletions tests/test_masked_softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2023 ⓒ Kakao Brain Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

import trident
from tests import util


def masked_softmax(input: torch.Tensor, mask: torch.Tensor, dim: int):
input = torch.where(mask.bool(), float("-inf"), input)
output = torch.nn.functional.softmax(input, dim)

return output


def build_mask(y_size: int, x_size: int, device=None):
mask = torch.randint(0, 2, (y_size, x_size), device=device)
mask[0, :] = mask[:, 0] = 0

return mask


@pytest.mark.parametrize("y_size, x_size, dim", [(2, 512, 0), (3, 1000, 1)])
def test_forward(y_size, x_size, dim, device):
input = torch.randn(y_size, x_size, device=device)
mask = build_mask(y_size, x_size, device)

assert util.equal(masked_softmax(input, mask, dim), trident.function.masked_softmax(input, mask, dim))


@pytest.mark.parametrize("y_size, x_size, dim", [(3, 1000, 0), (2, 512, 1)])
def test_backward(y_size, x_size, dim, device):
input = torch.randn(y_size, x_size, device=device)
mask = build_mask(y_size, x_size, device)
grad_output = torch.randn(y_size, x_size, device=device)

def train(func):
i = input.clone()
i.requires_grad = True
func(i, mask, dim).backward(grad_output, retain_graph=True)
return (i.grad,)

(x,) = train(masked_softmax)
(a,) = train(trident.function.masked_softmax)

assert util.equal(x, a)


@pytest.mark.parametrize("y_size, x_size, dim", [(1, 32, 1)])
def test_masked_softmax(y_size, x_size, dim, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
mask = build_mask(y_size, x_size, device)
grad_output = torch.randn_like(input)

output = trident.MaskedSoftmax(dim).forward(input, mask)

assert output is not None
assert output.dtype == dtype

output.backward(grad_output)

assert input.grad is not None
assert input.grad.dtype == dtype
8 changes: 5 additions & 3 deletions tests/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ def test_backward(y_size, x_size, dim, device):
input = torch.randn(y_size, x_size, device=device)
grad_output = torch.randn(y_size, x_size, device=device)

def train(func, dim):
def train(func):
i = input.clone()
i.requires_grad = True
func(i, dim).backward(grad_output, retain_graph=True)
return (i.grad,)

(x,) = train(torch.nn.functional.softmax, dim)
(a,) = train(trident.function.softmax, dim)
(x,) = train(torch.nn.functional.softmax)
(a,) = train(trident.function.softmax)

assert util.equal(x, a)

Expand All @@ -50,9 +50,11 @@ def test_softmax(y_size, x_size, dim, device, dtype):
grad_output = torch.randn_like(input)

output = trident.Softmax(dim).forward(input)

assert output is not None
assert output.dtype == dtype

output.backward(grad_output)

assert input.grad is not None
assert input.grad.dtype == dtype
9 changes: 9 additions & 0 deletions trident/function/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ def linear(
return operation.Linear.apply(input, weight, bias, use_accelerator)


def masked_softmax(input: torch.Tensor, mask: torch.Tensor, dim: int):
"""
Applies Masked Softmax to an input rescaling them so that an output lie in the range [0,1] and sum to 1.
See MaskedSoftmax for more details.
"""
return operation.MaskedSoftmax.apply(input, mask, dim)


def max(input: torch.Tensor, dim: int):
"""
Returns the max along the specified dimension in an input.
Expand Down
1 change: 1 addition & 0 deletions trident/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .layer_norm import *
from .leaky_relu import *
from .linear import *
from .masked_softmax import *
from .max import *
from .mean import *
from .prelu import *
Expand Down
Loading

0 comments on commit 6eb01b9

Please sign in to comment.