From 04073ec4c25b2e972218390ebcb7967623d488f1 Mon Sep 17 00:00:00 2001
From: "danny.jang"
Date: Wed, 4 Oct 2023 11:43:30 +0900
Subject: [PATCH] Implement MaskedSoftmax

---
 benchmarks/benchmark_masked_softmax.py |  66 ++++++++
 benchmarks/benchmarker.py              |   5 +
 tests/test_masked_softmax.py           |  77 ++++++++++
 tests/test_softmax.py                  |   8 +-
 trident/function/function.py           |   9 ++
 trident/kernel/__init__.py             |   1 +
 trident/kernel/masked_softmax.py       | 199 +++++++++++++++++++++++++
 trident/module.py                      |  35 +++++
 trident/operation/__init__.py          |   1 +
 trident/operation/masked_softmax.py    | 114 ++++++++++++++
 10 files changed, 512 insertions(+), 3 deletions(-)
 create mode 100644 benchmarks/benchmark_masked_softmax.py
 create mode 100644 tests/test_masked_softmax.py
 create mode 100644 trident/kernel/masked_softmax.py
 create mode 100644 trident/operation/masked_softmax.py

diff --git a/benchmarks/benchmark_masked_softmax.py b/benchmarks/benchmark_masked_softmax.py
new file mode 100644
index 00000000..f239f9d7
--- /dev/null
+++ b/benchmarks/benchmark_masked_softmax.py
@@ -0,0 +1,66 @@
+# Copyright 2023 ⓒ Kakao Brain Corp.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import triton
+import util
+
+import trident
+
+
+def masked_softmax(input: torch.Tensor, mask: torch.Tensor, dim: int):
+    input = torch.where(mask.bool(), float("-inf"), input)
+    output = torch.nn.functional.softmax(input, dim)
+
+    return output
+
+
+def build_mask(y_size: int, x_size: int, device=None, dtype=None):
+    mask = torch.randint(0, 2, (y_size, x_size), device=device, dtype=dtype)
+    mask[0, :] = mask[:, 0] = 0
+
+    return mask
+
+
+@util.report("masked softmax forward", ["x_size"], [128 * i for i in range(1, 21)], {"y_size": 16})
+def bench_masked_softmax_forward(y_size, x_size, dtype, backend):
+    input = torch.randn(y_size, x_size, device="cuda", dtype=dtype)
+    mask = build_mask(y_size, x_size, "cuda", dtype)
+
+    if backend == "torch":
+        return triton.testing.do_bench_cudagraph(lambda: masked_softmax(input, mask, 1))
+    else:
+        return triton.testing.do_bench_cudagraph(lambda: trident.function.masked_softmax(input, mask, 1))
+
+
+@util.report("masked softmax backward", ["x_size"], [128 * i for i in range(1, 21)], {"y_size": 16})
+def bench_masked_softmax_backward(y_size, x_size, dtype, backend):
+    factory_kwargs = {"device": "cuda", "dtype": dtype}
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    mask = build_mask(y_size, x_size, "cuda", dtype)
+    grad_output = torch.randn(y_size, x_size, **factory_kwargs)
+
+    if backend == "torch":
+        output = masked_softmax(input, mask, 1)
+    else:
+        output = trident.function.masked_softmax(input, mask, 1)
+
+    return triton.testing.do_bench_cudagraph(lambda: output.backward(grad_output, retain_graph=True))
+
+
+def run_benchmark(mode, show_plots, dtype):
+    if mode == "forward":
+        bench_masked_softmax_forward.run(print_data=True, show_plots=show_plots, dtype=dtype)
+    else:
+        bench_masked_softmax_backward.run(print_data=True, show_plots=show_plots, dtype=dtype)
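Note on the mask convention used throughout this patch: nonzero mask entries mark positions to exclude; they are filled with -inf before the softmax and therefore receive probability 0. build_mask zeroes row 0 and column 0 so that every row and every column keeps at least one unmasked element, which avoids all-masked slices whose softmax would be NaN. A standalone sketch of the convention (not part of the diff):

    import torch

    input = torch.tensor([[1.0, 2.0, 3.0]])
    mask = torch.tensor([[0, 1, 0]])  # 1 marks a position to exclude

    masked = torch.where(mask.bool(), float("-inf"), input)
    output = torch.nn.functional.softmax(masked, dim=1)

    assert output[0, 1] == 0.0  # the masked position gets probability 0
    assert torch.isclose(output.sum(), torch.tensor(1.0))  # the rest sum to 1
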
diff --git a/benchmarks/benchmarker.py b/benchmarks/benchmarker.py
index 377f4821..4521b4fa 100644
--- a/benchmarks/benchmarker.py
+++ b/benchmarks/benchmarker.py
@@ -26,6 +26,7 @@
 import benchmark_layer_norm
 import benchmark_leaky_relu
 import benchmark_linear
+import benchmark_masked_softmax
 import benchmark_max
 import benchmark_mean
 import benchmark_prelu
@@ -58,6 +59,7 @@ def print_scenarios():
         "layer-norm",
         "leaky-relu",
         "linear",
+        "masked-softmax",
         "max",
         "mean",
         "prelu",
@@ -99,6 +101,8 @@ def run_benchmarks(scenario, mode, show_plots, dtype):
         benchmark_leaky_relu.run_benchmark(mode, show_plots, dtype)
     elif scenario == "linear":
         benchmark_linear.run_benchmark(mode, show_plots, dtype)
+    elif scenario == "masked-softmax":
+        benchmark_masked_softmax.run_benchmark(mode, show_plots, dtype)
     elif scenario == "max":
         benchmark_max.run_benchmark(mode, show_plots, dtype)
     elif scenario == "mean":
@@ -134,6 +138,7 @@ def run_benchmarks(scenario, mode, show_plots, dtype):
         benchmark_layer_norm.run_benchmark(mode, show_plots, dtype)
         benchmark_leaky_relu.run_benchmark(mode, show_plots, dtype)
         benchmark_linear.run_benchmark(mode, show_plots, dtype)
+        benchmark_masked_softmax.run_benchmark(mode, show_plots, dtype)
         benchmark_max.run_benchmark(mode, show_plots, dtype)
         benchmark_mean.run_benchmark(mode, show_plots, dtype)
         benchmark_prelu.run_benchmark(mode, show_plots, dtype)
diff --git a/tests/test_masked_softmax.py b/tests/test_masked_softmax.py
new file mode 100644
index 00000000..5a0a34a7
--- /dev/null
+++ b/tests/test_masked_softmax.py
@@ -0,0 +1,77 @@
+# Copyright 2023 ⓒ Kakao Brain Corp.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+
+import trident
+from tests import util
+
+
+def masked_softmax(input: torch.Tensor, mask: torch.Tensor, dim: int):
+    input = torch.where(mask.bool(), float("-inf"), input)
+    output = torch.nn.functional.softmax(input, dim)
+
+    return output
+
+
+def build_mask(y_size: int, x_size: int, device=None):
+    mask = torch.randint(0, 2, (y_size, x_size), device=device)
+    mask[0, :] = mask[:, 0] = 0
+
+    return mask
+
+
+@pytest.mark.parametrize("y_size, x_size, dim", [(2, 512, 0), (3, 1000, 1)])
+def test_forward(y_size, x_size, dim, device):
+    input = torch.randn(y_size, x_size, device=device)
+    mask = build_mask(y_size, x_size, device)
+
+    assert util.equal(masked_softmax(input, mask, dim), trident.function.masked_softmax(input, mask, dim))
+
+
+@pytest.mark.parametrize("y_size, x_size, dim", [(3, 1000, 0), (2, 512, 1)])
+def test_backward(y_size, x_size, dim, device):
+    input = torch.randn(y_size, x_size, device=device)
+    mask = build_mask(y_size, x_size, device)
+    grad_output = torch.randn(y_size, x_size, device=device)
+
+    def train(func):
+        i = input.clone()
+        i.requires_grad = True
+        func(i, mask, dim).backward(grad_output, retain_graph=True)
+        return (i.grad,)
+
+    (x,) = train(masked_softmax)
+    (a,) = train(trident.function.masked_softmax)
+
+    assert util.equal(x, a)
+
+
+@pytest.mark.parametrize("y_size, x_size, dim", [(1, 32, 1)])
+def test_masked_softmax(y_size, x_size, dim, device, dtype):
+    factory_kwargs = {"device": device, "dtype": dtype}
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    mask = build_mask(y_size, x_size, device)
+    grad_output = torch.randn_like(input)
+
+    output = trident.MaskedSoftmax(dim).forward(input, mask)
+
+    assert output is not None
+    assert output.dtype == dtype
+
+    output.backward(grad_output)
+
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_softmax.py b/tests/test_softmax.py
index aa92f944..7b072246 100644
--- a/tests/test_softmax.py
+++ b/tests/test_softmax.py
@@ -31,14 +31,14 @@ def test_backward(y_size, x_size, dim, device):
     input = torch.randn(y_size, x_size, device=device)
     grad_output = torch.randn(y_size, x_size, device=device)
 
-    def train(func, dim):
+    def train(func):
         i = input.clone()
         i.requires_grad = True
         func(i, dim).backward(grad_output, retain_graph=True)
         return (i.grad,)
 
-    (x,) = train(torch.nn.functional.softmax, dim)
-    (a,) = train(trident.function.softmax, dim)
+    (x,) = train(torch.nn.functional.softmax)
+    (a,) = train(trident.function.softmax)
 
     assert util.equal(x, a)
 
@@ -50,9 +50,11 @@ def test_softmax(y_size, x_size, dim, device, dtype):
     grad_output = torch.randn_like(input)
 
     output = trident.Softmax(dim).forward(input)
+
     assert output is not None
     assert output.dtype == dtype
 
     output.backward(grad_output)
+
     assert input.grad is not None
     assert input.grad.dtype == dtype
diff --git a/trident/function/function.py b/trident/function/function.py
index 3a6b16c1..3089e328 100644
--- a/trident/function/function.py
+++ b/trident/function/function.py
@@ -163,6 +163,15 @@ def linear(
     return operation.Linear.apply(input, weight, bias, use_accelerator)
 
 
+def masked_softmax(input: torch.Tensor, mask: torch.Tensor, dim: int):
+    """
+    Applies masked softmax to an input, rescaling it so that the elements of the output lie in the range [0, 1] and sum to 1.
+
+    See MaskedSoftmax for more details.
+    """
+    return operation.MaskedSoftmax.apply(input, mask, dim)
+
+
 def max(input: torch.Tensor, dim: int):
     """
     Returns the max along the specified dimension in an input.
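For context, the new public entry point can be exercised directly; a minimal usage sketch (not part of the diff), assuming a CUDA device and the patched trident build:

    import torch
    import trident

    input = torch.randn(4, 128, device="cuda")
    mask = torch.randint(0, 2, (4, 128), device="cuda")
    mask[:, 0] = 0  # keep at least one unmasked element per row

    output = trident.function.masked_softmax(input, mask, 1)

    # Each row is a probability distribution over its unmasked positions.
    assert torch.allclose(output.sum(1), torch.ones(4, device="cuda"), atol=1e-5)
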
diff --git a/trident/kernel/__init__.py b/trident/kernel/__init__.py
index 0c2020e2..b2d34a8d 100644
--- a/trident/kernel/__init__.py
+++ b/trident/kernel/__init__.py
@@ -24,6 +24,7 @@
 from .layer_norm import *
 from .leaky_relu import *
 from .linear import *
+from .masked_softmax import *
 from .max import *
 from .mean import *
 from .prelu import *
diff --git a/trident/kernel/masked_softmax.py b/trident/kernel/masked_softmax.py
new file mode 100644
index 00000000..dffd57f9
--- /dev/null
+++ b/trident/kernel/masked_softmax.py
@@ -0,0 +1,199 @@
+# Copyright 2023 ⓒ Kakao Brain Corp.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import triton
+import triton.language as tl
+
+from trident import language, util
+
+
+class MaskedSoftmax:
+    @staticmethod
+    @triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
+    @triton.jit
+    def forward(
+        output_ptr: tl.tensor,
+        input_ptr: tl.tensor,
+        y_size: tl.int32,
+        x_size: tl.int32,
+        y_stride: tl.int32,
+        x_stride: tl.int32,
+        mask_ptr: tl.tensor,
+        dtype: tl.constexpr,
+        x_block_size: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
+    ):
+        y_offset = tl.program_id(0)
+
+        output_block_ptr = tl.make_block_ptr(
+            output_ptr,
+            shape=(y_size, x_size),
+            strides=(y_stride, x_stride),
+            offsets=(y_offset, 0),
+            block_shape=(1, x_block_size),
+            order=(1, 0),
+        )
+        input_block_ptr = tl.make_block_ptr(
+            input_ptr,
+            shape=(y_size, x_size),
+            strides=(y_stride, x_stride),
+            offsets=(y_offset, 0),
+            block_shape=(1, x_block_size),
+            order=(1, 0),
+        )
+        mask_block_ptr = tl.make_block_ptr(
+            mask_ptr,
+            shape=(y_size, x_size),
+            strides=(y_stride, x_stride),
+            offsets=(y_offset, 0),
+            block_shape=(1, x_block_size),
+            order=(1, 0),
+        )
+
+        if require_x_boundary_check:
+            input = tl.load(input_block_ptr, boundary_check=(1,))
+            mask = tl.load(mask_block_ptr, boundary_check=(1,))
+            condition = tl.arange(0, x_block_size) < x_size
+            mask = tl.where(condition, mask, 1)
+        else:
+            input = tl.load(input_block_ptr)
+            mask = tl.load(mask_block_ptr)
+
+        input = tl.where(mask > language.eps, float("-inf"), input)
+        max = tl.max(input, 1)
+        numerator = tl.math.fast_expf(input - max)
+        output = numerator / tl.sum(numerator, 1)
+
+        if require_x_boundary_check:
+            tl.store(output_block_ptr, output.to(dtype), boundary_check=(1,))
+        else:
+            tl.store(output_block_ptr, output.to(dtype))
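The forward kernel is a single-block, numerically stable softmax: each program handles one row, masked (and out-of-bounds) positions are forced to -inf, the row maximum is subtracted before tl.math.fast_expf so the exponentials cannot overflow, and the row is normalized by their sum; exp(-inf) contributes 0, so excluded positions drop out of the normalizer. The same row-wise computation in eager PyTorch, as a reference sketch (not part of the diff):

    import torch

    def reference_row(input_row: torch.Tensor, mask_row: torch.Tensor) -> torch.Tensor:
        # Masked positions become -inf, mirroring the kernel's tl.where.
        x = torch.where(mask_row.bool(), float("-inf"), input_row)
        x = x - x.max()           # subtract the row max for numerical stability
        numerator = torch.exp(x)  # exp(-inf) == 0, so masked positions vanish
        return numerator / numerator.sum()
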
+    @staticmethod
+    @triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
+    @triton.jit
+    def backward(
+        grad_input_ptr: tl.tensor,
+        grad_output_ptr: tl.tensor,
+        output_ptr: tl.tensor,
+        delta_ptr: tl.tensor,
+        y_size: tl.int32,
+        x_size: tl.int32,
+        y_stride: tl.int32,
+        x_stride: tl.int32,
+        dtype: tl.constexpr,
+        x_block_size: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
+    ):
+        y_offset = tl.program_id(0)
+
+        grad_input_block_ptr = tl.make_block_ptr(
+            grad_input_ptr,
+            shape=(y_size, x_size),
+            strides=(y_stride, x_stride),
+            offsets=(y_offset, 0),
+            block_shape=(1, x_block_size),
+            order=(1, 0),
+        )
+        grad_output_block_ptr = tl.make_block_ptr(
+            grad_output_ptr,
+            shape=(y_size, x_size),
+            strides=(y_stride, x_stride),
+            offsets=(y_offset, 0),
+            block_shape=(1, x_block_size),
+            order=(1, 0),
+        )
+        output_block_ptr = tl.make_block_ptr(
+            output_ptr,
+            shape=(y_size, x_size),
+            strides=(y_stride, x_stride),
+            offsets=(y_offset, 0),
+            block_shape=(1, x_block_size),
+            order=(1, 0),
+        )
+        delta_block_ptr = tl.make_block_ptr(
+            delta_ptr,
+            shape=(y_size,),
+            strides=(1,),
+            offsets=(y_offset,),
+            block_shape=(1,),
+            order=(0,),
+        )
+
+        if require_x_boundary_check:
+            output = tl.load(output_block_ptr, boundary_check=(1,))
+            grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,))
+        else:
+            output = tl.load(output_block_ptr)
+            grad_output = tl.load(grad_output_block_ptr)
+
+        delta = tl.load(delta_block_ptr)
+        grad_input = output * (grad_output - delta)
+
+        if require_x_boundary_check:
+            tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1,))
+        else:
+            tl.store(grad_input_block_ptr, grad_input.to(dtype))
+
+    @staticmethod
+    @triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
+    @triton.jit
+    def backward_delta(
+        delta_ptr: tl.tensor,
+        grad_output_ptr: tl.tensor,
+        output_ptr: tl.tensor,
+        y_size: tl.int32,
+        x_size: tl.int32,
+        y_stride: tl.int32,
+        x_stride: tl.int32,
+        dtype: tl.constexpr,
+        x_block_size: tl.constexpr,
+        require_x_boundary_check: tl.constexpr,
+    ):
+        y_offset = tl.program_id(0)
+
+        delta_block_ptr = tl.make_block_ptr(
+            delta_ptr,
+            shape=(y_size,),
+            strides=(1,),
+            offsets=(y_offset,),
+            block_shape=(1,),
+            order=(0,),
+        )
+        grad_output_block_ptr = tl.make_block_ptr(
+            grad_output_ptr,
+            shape=(y_size, x_size),
+            strides=(y_stride, x_stride),
+            offsets=(y_offset, 0),
+            block_shape=(1, x_block_size),
+            order=(1, 0),
+        )
+        output_block_ptr = tl.make_block_ptr(
+            output_ptr,
+            shape=(y_size, x_size),
+            strides=(y_stride, x_stride),
+            offsets=(y_offset, 0),
+            block_shape=(1, x_block_size),
+            order=(1, 0),
+        )
+
+        if require_x_boundary_check:
+            grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,), padding_option="zero")
+            output = tl.load(output_block_ptr, boundary_check=(1,))
+        else:
+            grad_output = tl.load(grad_output_block_ptr)
+            output = tl.load(output_block_ptr)
+
+        delta = tl.sum(grad_output * output, 1)
+        tl.store(delta_block_ptr, delta.to(dtype))
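Together, backward_delta and backward implement the standard softmax Jacobian-vector product: with y = softmax(x) along a row, the gradient is dx_i = y_i * (dy_i - sum_j dy_j * y_j). backward_delta computes the per-row reduction delta = sum_j dy_j * y_j, and backward applies the elementwise formula. No mask is needed in the backward pass: masked positions have y_i = 0, so their gradient is automatically 0. A quick eager-mode check of the identity (a sketch, independent of the kernels):

    import torch

    x = torch.randn(2, 8, requires_grad=True)
    y = torch.nn.functional.softmax(x, dim=1)
    dy = torch.randn_like(y)
    y.backward(dy)

    delta = (dy * y).sum(1, keepdim=True)  # what backward_delta computes per row
    manual = y * (dy - delta)              # what backward computes per element
    assert torch.allclose(x.grad, manual, atol=1e-6)
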
diff --git a/trident/module.py b/trident/module.py
index fa16f452..d0651d76 100644
--- a/trident/module.py
+++ b/trident/module.py
@@ -826,6 +826,41 @@ def reset_parameters(self):
         util.uniform(self.bias, -bound, bound)
 
 
+class MaskedSoftmax(torch.nn.Module):
+    def __init__(self, dim: int = None):
+        """
+        Applies masked softmax to an input, rescaling it so that the elements of the output lie in the range [0, 1] and sum to 1.
+
+        Args:
+            dim: a dimension along which masked softmax will be computed (so every slice along dim will sum to 1)
+        """
+        super().__init__()
+
+        self.dim = dim
+
+    def forward(self, input: torch.Tensor, mask: torch.Tensor):
+        """
+        Applies masked softmax to the input.
+
+        Args:
+            input: an input to be rescaled
+            mask: a mask; nonzero positions are excluded from the softmax
+
+        Returns:
+            an output with the same shape as the input, with values in the range [0, 1]
+        """
+        return function.masked_softmax(input, mask, self.dim)
+
+    def extra_repr(self):
+        """
+        Set the extra representation of the module.
+
+        Returns:
+            customized extra information
+        """
+        return f"dim={self.dim}, backend=Trident"
+
+
 class Max(torch.nn.Module):
     def __init__(self, dim: torch.int32):
         """
diff --git a/trident/operation/__init__.py b/trident/operation/__init__.py
index 0c2020e2..b2d34a8d 100644
--- a/trident/operation/__init__.py
+++ b/trident/operation/__init__.py
@@ -24,6 +24,7 @@
 from .layer_norm import *
 from .leaky_relu import *
 from .linear import *
+from .masked_softmax import *
 from .max import *
 from .mean import *
 from .prelu import *
diff --git a/trident/operation/masked_softmax.py b/trident/operation/masked_softmax.py
new file mode 100644
index 00000000..52c7ed91
--- /dev/null
+++ b/trident/operation/masked_softmax.py
@@ -0,0 +1,114 @@
+# Copyright 2023 ⓒ Kakao Brain Corp.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import torch
+import triton
+
+from trident import kernel, util
+
+
+class MaskedSoftmax(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx: Any, *args: Any, **kwargs: Any):
+        input, mask, dim = args
+
+        util.push_trace("MaskedSoftmax.__forward")
+        output = MaskedSoftmax.__forward(input, mask, dim)
+        util.pop_trace()
+
+        ctx.save_for_backward(output)
+        ctx.dim = dim
+
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Any):
+        (grad_output,) = grad_outputs
+        (output,) = ctx.saved_tensors
+
+        util.push_trace("MaskedSoftmax.__backward")
+        grad_input = MaskedSoftmax.__backward(grad_output, output, ctx.dim)
+        util.pop_trace()
+
+        return grad_input, None, None
+
+    @staticmethod
+    def __forward(input: torch.Tensor, mask: torch.Tensor, dim: torch.int32):
+        y_size, x_size, y_stride, x_stride = util.size_and_stride(input, dim)
+        output = torch.empty_like(input)
+
+        def grid(meta):
+            return (y_size,)
+
+        util.push_trace("kernel.MaskedSoftmax.forward")
+        kernel.MaskedSoftmax.forward[grid](
+            output,
+            input,
+            y_size,
+            x_size,
+            y_stride,
+            x_stride,
+            mask,
+            util.dtype(output.dtype),
+            triton.next_power_of_2(x_size),
+        )
+        util.pop_trace()
+
+        return output
+
+    @staticmethod
+    def __backward(grad_output: torch.Tensor, output: torch.Tensor, dim: torch.int32):
+        factory_kwargs = {"device": output.device, "dtype": output.dtype}
+        y_size, x_size, y_stride, x_stride = util.size_and_stride(output, dim)
+        delta = torch.empty(y_size, **factory_kwargs)
+        grad_input = torch.empty_like(output)
+
+        def grid(meta):
+            return (y_size,)
+
+        util.push_trace("kernel.MaskedSoftmax.backward_delta")
+        kernel.MaskedSoftmax.backward_delta[grid](
+            delta,
+            grad_output,
+            output,
+            y_size,
+            x_size,
+            y_stride,
+            x_stride,
+            util.dtype(delta.dtype),
+            triton.next_power_of_2(x_size),
+        )
+        util.pop_trace()
+
+        def grid(meta):
+            return (y_size,)
+
+        util.push_trace("kernel.MaskedSoftmax.backward")
+        kernel.MaskedSoftmax.backward[grid](
+            grad_input,
+            grad_output,
+            output,
+            delta,
+            y_size,
+            x_size,
+            y_stride,
+            x_stride,
+            util.dtype(output.dtype),
+            triton.next_power_of_2(x_size),
+        )
+        util.pop_trace()
+
+        return grad_input
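
Taken together, an end-to-end sketch of the new module (not part of the diff, assuming a CUDA device): masked positions receive probability 0 in the forward pass and, since grad_input = output * (grad_output - delta), gradient 0 in the backward pass.

    import torch
    import trident

    module = trident.MaskedSoftmax(dim=1)

    input = torch.randn(3, 64, device="cuda", requires_grad=True)
    mask = torch.randint(0, 2, (3, 64), device="cuda")
    mask[:, 0] = 0  # guarantee an unmasked element per row

    output = module(input, mask)
    output.backward(torch.randn_like(output))

    assert (output[mask.bool()] == 0).all()      # excluded positions get probability 0
    assert (input.grad[mask.bool()] == 0).all()  # and receive no gradient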