forked from DeepLink-org/deeplink.framework
[dicp][tops] Support some ops for stable-diffusion. (DeepLink-org#467)
* Add sin, cos, erf, split:
  1. Generalize MakeTuple in tops_op.
  2. Generalize make_const in enflame codegen.
  3. Add sin, cos, erf, split for tops.
  4. Format Python code in dicp tops.
* Refine code.
* Fix abs test path.
* Clean up code of split.
* Adjust const op generation.
* Fix nullptr case in const generation.

---------

Co-authored-by: jinminxi104 <[email protected]>
Co-authored-by: Reinerzhou <[email protected]>
1 parent a190a80 | commit 51978d9 | Showing 9 changed files with 250 additions and 31 deletions.
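All three new unary ops are exercised through their ATen `.default` overloads, with eager PyTorch as the reference the dicp backend is checked against. These overloads are callable directly in eager mode, for example:

    import torch

    x = torch.randn(4)
    # The .default overload used in the tests below is the standard ATen
    # entry point; its eager result matches the Tensor method.
    ref = torch.ops.aten.cos.default(x)
    assert torch.allclose(ref, x.cos())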
@@ -0,0 +1,40 @@
import pytest
from ..common.utils import (
    torch,
    dynamo,
    parse_args,
    compile_model,
    get_device,
    Size,
    update_dynamo_config,
)


class OpModule(torch.nn.Module):
    def forward(self, a):
        res_default = torch.ops.aten.cos.default(a)
        return res_default


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestCos():
    @pytest.mark.parametrize("dtype", [torch.float32])
    @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
    @pytest.mark.parametrize("compiled_model", compiled_model)
    def test_torch_cos(self, sizes, dtype, compiled_model):
        device = get_device()
        size = sizes.dynamic if compiled_model.dynamic else sizes.static
        input1 = torch.randn(size, dtype=dtype)

        dicp_input1 = input1.to(device)

        output = model(input1)
        dynamo.reset()
        update_dynamo_config(compiled_model.dynamic)
        dicp_output = compiled_model.model(dicp_input1)

        assert torch.allclose(output, dicp_output.cpu(), atol=1e-04, equal_nan=True)
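The `Size` helper imported from `..common.utils` is not part of this diff; judging from how the tests read `.static` and `.dynamic`, a minimal sketch of it (an assumption, not the actual definition) could be:

    from collections import namedtuple

    # Assumed sketch: one shape for static-shape runs, a second shape used to
    # exercise dynamic-shape compilation of the same graph.
    Size = namedtuple("Size", ["static", "dynamic"])

    sizes = Size((5,), (5, 3))
    print(sizes.static)   # (5,)
    print(sizes.dynamic)  # (5, 3)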
@@ -0,0 +1,40 @@
import pytest
from ..common.utils import (
    torch,
    dynamo,
    parse_args,
    compile_model,
    get_device,
    Size,
    update_dynamo_config,
)


class OpModule(torch.nn.Module):
    def forward(self, a):
        res_default = torch.ops.aten.erf.default(a)
        return res_default


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestErf():
    @pytest.mark.parametrize("dtype", [torch.float32])
    @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
    @pytest.mark.parametrize("compiled_model", compiled_model)
    def test_torch_erf(self, sizes, dtype, compiled_model):
        device = get_device()
        size = sizes.dynamic if compiled_model.dynamic else sizes.static
        input1 = torch.randn(size, dtype=dtype)

        dicp_input1 = input1.to(device)

        output = model(input1)
        dynamo.reset()
        update_dynamo_config(compiled_model.dynamic)
        dicp_output = compiled_model.model(dicp_input1)

        assert torch.allclose(output, dicp_output.cpu(), atol=1e-04, equal_nan=True)
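Each test calls `dynamo.reset()` before running the compiled model so that graphs cached by earlier parametrized cases are discarded, then toggles dynamic-shape handling via `update_dynamo_config`. Assuming `dynamo` is `torch._dynamo` re-exported by `common.utils`, the reset corresponds to:

    import torch._dynamo as dynamo

    # Drop all cached compiled graphs; the next compiled call retraces from
    # scratch instead of reusing a graph specialized for a previous
    # shape/config combination.
    dynamo.reset()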
@@ -0,0 +1,40 @@
import pytest
from ..common.utils import (
    torch,
    dynamo,
    parse_args,
    compile_model,
    get_device,
    Size,
    update_dynamo_config,
)


class OpModule(torch.nn.Module):
    def forward(self, a):
        res_default = torch.ops.aten.sin.default(a)
        return res_default


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestSin():
    @pytest.mark.parametrize("dtype", [torch.float32])
    @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
    @pytest.mark.parametrize("compiled_model", compiled_model)
    def test_torch_sin(self, sizes, dtype, compiled_model):
        device = get_device()
        size = sizes.dynamic if compiled_model.dynamic else sizes.static
        input1 = torch.randn(size, dtype=dtype)

        dicp_input1 = input1.to(device)

        output = model(input1)
        dynamo.reset()
        update_dynamo_config(compiled_model.dynamic)
        dicp_output = compiled_model.model(dicp_input1)

        assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)
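The commit message also adds `split`, but its test file is not visible in this truncated rendering. Purely as a hypothetical sketch following the pattern of the cos/erf/sin tests above (the actual file may differ), it would plausibly look like:

    import pytest
    from ..common.utils import (
        torch,
        dynamo,
        parse_args,
        compile_model,
        get_device,
        Size,
        update_dynamo_config,
    )


    class OpModule(torch.nn.Module):
        def forward(self, a):
            # split returns multiple tensors, which is presumably why the
            # commit generalizes MakeTuple in tops_op.
            res_default = torch.ops.aten.split.Tensor(a, 2)
            return res_default


    model = OpModule()
    args = parse_args()
    compiled_model = compile_model(model, args.backend, args.dynamic)


    class TestSplit():
        @pytest.mark.parametrize("dtype", [torch.float32])
        @pytest.mark.parametrize("sizes", [Size((4, 3), (4, 2))])
        @pytest.mark.parametrize("compiled_model", compiled_model)
        def test_torch_split(self, sizes, dtype, compiled_model):
            device = get_device()
            size = sizes.dynamic if compiled_model.dynamic else sizes.static
            input1 = torch.randn(size, dtype=dtype)

            dicp_input1 = input1.to(device)

            outputs = model(input1)
            dynamo.reset()
            update_dynamo_config(compiled_model.dynamic)
            dicp_outputs = compiled_model.model(dicp_input1)

            # Compare each chunk of the multi-output result.
            for out, dicp_out in zip(outputs, dicp_outputs):
                assert torch.allclose(out, dicp_out.cpu(), equal_nan=True)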