Skip to content

Commit

Permalink
[Kernel] Register punica ops directly (#10522)
Browse files Browse the repository at this point in the history
Signed-off-by: Jee Jee Li <[email protected]>
  • Loading branch information
jeejeelee authored Nov 21, 2024
1 parent da7e702 commit 2385b60
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 24 deletions.
23 changes: 17 additions & 6 deletions tests/lora/test_punica_variation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import pytest
import torch

from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
# Enable custom op register
import vllm.lora.ops.bgmv_expand
import vllm.lora.ops.bgmv_expand_slice
import vllm.lora.ops.bgmv_shrink
import vllm.lora.ops.sgmv_expand
import vllm.lora.ops.sgmv_expand_slice
import vllm.lora.ops.sgmv_shrink # noqa: F401
from vllm.platforms import current_platform

from .utils import (generate_data, generate_data_for_expand_nslices,
Expand All @@ -37,6 +38,16 @@ def assert_close(a, b):
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


# Unlike test_punica_sizes.py, we call the registered custom ops
# (torch.ops.vllm.*) directly rather than the Python wrappers; this
# verifies that importing the vllm.lora.ops modules above registered
# each op correctly.
bgmv_expand = torch.ops.vllm.bgmv_expand
bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
bgmv_shrink = torch.ops.vllm.bgmv_shrink
sgmv_expand = torch.ops.vllm.sgmv_expand
sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice
sgmv_shrink = torch.ops.vllm.sgmv_shrink


@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
Expand Down
23 changes: 20 additions & 3 deletions vllm/lora/ops/bgmv_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import triton
import triton.language as tl

from vllm.utils import direct_register_custom_op

from .utils import get_lora_op_configs


Expand Down Expand Up @@ -162,9 +164,24 @@ def _bgmv_expand(
return


def bgmv_expand_fake(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
) -> None:
    """Fake (meta) implementation used when registering the real op.

    Matches the real op's signature but does no work; the real kernel
    writes its result into ``output_tensor`` in place, so there is
    nothing to return.
    """
    pass


try:
    # Register the Triton kernel as a vLLM custom op so it is reachable
    # as torch.ops.vllm.bgmv_expand. The fake impl lets tracing systems
    # (e.g. torch.compile) run the op without executing the kernel.
    direct_register_custom_op(
        op_name="bgmv_expand",
        op_func=_bgmv_expand,
        mutates_args=["output_tensor"],
        fake_impl=bgmv_expand_fake,
    )
    bgmv_expand = torch.ops.vllm.bgmv_expand

except AttributeError:
    # Custom-op registration unavailable (e.g. older torch): fall back
    # to calling the Python implementation directly.
    bgmv_expand = _bgmv_expand
25 changes: 22 additions & 3 deletions vllm/lora/ops/bgmv_expand_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import triton
import triton.language as tl

from vllm.utils import direct_register_custom_op

from .utils import get_lora_op_configs


Expand Down Expand Up @@ -179,9 +181,26 @@ def _bgmv_expand_slice(
return


def bgmv_expand_slice_fake(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
) -> None:
    """No-op fake implementation mirroring the real op's signature.

    The real kernel mutates ``output_tensor`` in place, hence the
    ``None`` return.
    """
    return None


try:
    # Register the Triton kernel as torch.ops.vllm.bgmv_expand_slice,
    # supplying a fake impl for tracing (e.g. torch.compile).
    direct_register_custom_op(
        op_name="bgmv_expand_slice",
        op_func=_bgmv_expand_slice,
        mutates_args=["output_tensor"],
        fake_impl=bgmv_expand_slice_fake,
    )
    bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice

except AttributeError:
    # Registration unsupported on this torch build: use the Python
    # implementation directly.
    bgmv_expand_slice = _bgmv_expand_slice
23 changes: 20 additions & 3 deletions vllm/lora/ops/bgmv_shrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import triton
import triton.language as tl

from vllm.utils import direct_register_custom_op

from .utils import get_lora_op_configs


Expand Down Expand Up @@ -142,9 +144,24 @@ def _bgmv_shrink(
return


def bgmv_shrink_fake(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
) -> None:
    """Fake (meta) counterpart of the bgmv_shrink kernel.

    Performs no computation; the real op writes into ``output_tensor``
    in place.
    """
    pass


try:
    # Register the Triton kernel as torch.ops.vllm.bgmv_shrink with a
    # fake impl so tracing (e.g. torch.compile) can handle the op.
    direct_register_custom_op(
        op_name="bgmv_shrink",
        op_func=_bgmv_shrink,
        mutates_args=["output_tensor"],
        fake_impl=bgmv_shrink_fake,
    )
    bgmv_shrink = torch.ops.vllm.bgmv_shrink

except AttributeError:
    # Registration unsupported on this torch build: fall back to the
    # Python implementation.
    bgmv_shrink = _bgmv_shrink
29 changes: 26 additions & 3 deletions vllm/lora/ops/sgmv_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import triton
import triton.language as tl

from vllm.utils import direct_register_custom_op


@triton.jit
def _sgmv_expand_kernel(
Expand Down Expand Up @@ -196,9 +198,30 @@ def _sgmv_expand(
return


def sgmv_expand_fake(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
) -> None:
    """No-op fake implementation with the same signature as sgmv_expand.

    The real kernel mutates ``output_tensor`` in place, so nothing is
    returned.
    """
    return None


try:
    # Register the Triton kernel as torch.ops.vllm.sgmv_expand; the
    # fake impl supports tracing (e.g. torch.compile) without running
    # the kernel.
    direct_register_custom_op(
        op_name="sgmv_expand",
        op_func=_sgmv_expand,
        mutates_args=["output_tensor"],
        fake_impl=sgmv_expand_fake,
    )
    sgmv_expand = torch.ops.vllm.sgmv_expand

except AttributeError:
    # Registration unsupported on this torch build: use the Python
    # implementation directly.
    sgmv_expand = _sgmv_expand
30 changes: 27 additions & 3 deletions vllm/lora/ops/sgmv_expand_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import triton
import triton.language as tl

from vllm.utils import direct_register_custom_op


@triton.jit
def _sgmv_expand_slice_kernel(
Expand Down Expand Up @@ -209,9 +211,31 @@ def _sgmv_expand_slice(
return


def sgmv_expand_slice_fake(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
) -> None:
    """Fake (meta) implementation matching sgmv_expand_slice's signature.

    Does no work; the real kernel writes into ``output_tensor`` in
    place.
    """
    pass


try:
    # Register the Triton kernel as torch.ops.vllm.sgmv_expand_slice,
    # with a fake impl for tracing (e.g. torch.compile).
    direct_register_custom_op(
        op_name="sgmv_expand_slice",
        op_func=_sgmv_expand_slice,
        mutates_args=["output_tensor"],
        fake_impl=sgmv_expand_slice_fake,
    )
    sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice

except AttributeError:
    # Registration unsupported on this torch build: fall back to the
    # Python implementation.
    sgmv_expand_slice = _sgmv_expand_slice
28 changes: 25 additions & 3 deletions vllm/lora/ops/sgmv_shrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import triton
import triton.language as tl

from vllm.utils import direct_register_custom_op


@triton.jit
def _sgmv_shrink_kernel(
Expand Down Expand Up @@ -190,9 +192,29 @@ def _sgmv_shrink(
return


def sgmv_shrink_fake(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
) -> None:
    """No-op fake implementation mirroring the sgmv_shrink kernel.

    The real op mutates ``output_tensor`` in place; this stub simply
    returns ``None``.
    """
    return None


try:
    # Register the Triton kernel as torch.ops.vllm.sgmv_shrink; the
    # fake impl lets tracing systems (e.g. torch.compile) handle the op
    # without executing the kernel.
    direct_register_custom_op(
        op_name="sgmv_shrink",
        op_func=_sgmv_shrink,
        mutates_args=["output_tensor"],
        fake_impl=sgmv_shrink_fake,
    )
    sgmv_shrink = torch.ops.vllm.sgmv_shrink

except AttributeError:
    # Registration unsupported on this torch build: use the Python
    # implementation directly.
    sgmv_shrink = _sgmv_shrink

0 comments on commit 2385b60

Please sign in to comment.