-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1. Generalize MakeTuple in tops_op. 2. Generalize make_const in enflame codegen. 3. Add sin, cos, erf, split for tops. 4. Format Python code in dicp tops.
- Loading branch information
1 parent
870e796
commit 47d0f8a
Showing
9 changed files
with
270 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import pytest | ||
from common.utils import ( | ||
torch, | ||
dynamo, | ||
parse_args, | ||
compile_model, | ||
get_device, | ||
Size, | ||
update_dynamo_config, | ||
) | ||
|
||
|
||
class OpModule(torch.nn.Module):
    """Minimal wrapper module that applies ``aten.cos.default`` to its input.

    The aten op is called directly (rather than ``torch.cos``) so the graph
    captured by the compiler contains exactly this node.
    """

    def forward(self, a):
        return torch.ops.aten.cos.default(a)
|
||
|
||
# Eager-mode reference model used to produce the expected output.
model = OpModule()
# Backend/dynamic-shape flags supplied by the test harness command line.
args = parse_args()
# Compiled variants of the model; iterated by the pytest parametrization
# below.  # NOTE(review): presumably one entry per dynamic setting — confirm
# against common.utils.compile_model.
compiled_model = compile_model(model, args.backend, args.dynamic)
|
||
|
||
class TestCos():
    """Check ``aten.cos.default`` on the dicp backend against eager CPU."""

    @pytest.mark.parametrize("dtype", [torch.float32])
    @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
    @pytest.mark.parametrize("compiled_model", compiled_model)
    def test_torch_cos(self, sizes, dtype, compiled_model):
        # Pick the static or dynamic shape depending on the compiled variant.
        shape = sizes.dynamic if compiled_model.dynamic else sizes.static
        cpu_input = torch.randn(shape, dtype=dtype)
        device_input = cpu_input.to(get_device())

        # Eager reference result on CPU.
        expected = model(cpu_input)

        # Reset dynamo state so each parametrized case compiles fresh.
        dynamo.reset()
        update_dynamo_config(compiled_model.dynamic)
        actual = compiled_model.model(device_input)

        assert torch.allclose(expected, actual.cpu(), atol=1e-04, equal_nan=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import pytest | ||
from common.utils import ( | ||
torch, | ||
dynamo, | ||
parse_args, | ||
compile_model, | ||
get_device, | ||
Size, | ||
update_dynamo_config, | ||
) | ||
|
||
|
||
class OpModule(torch.nn.Module):
    """Minimal wrapper module that applies ``aten.erf.default`` to its input.

    The aten op is called directly (rather than ``torch.erf``) so the graph
    captured by the compiler contains exactly this node.
    """

    def forward(self, a):
        return torch.ops.aten.erf.default(a)
|
||
|
||
# Eager-mode reference model used to produce the expected output.
model = OpModule()
# Backend/dynamic-shape flags supplied by the test harness command line.
args = parse_args()
# Compiled variants of the model; iterated by the pytest parametrization
# below.  # NOTE(review): presumably one entry per dynamic setting — confirm
# against common.utils.compile_model.
compiled_model = compile_model(model, args.backend, args.dynamic)
|
||
|
||
class TestErf():
    """Check ``aten.erf.default`` on the dicp backend against eager CPU."""

    @pytest.mark.parametrize("dtype", [torch.float32])
    @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
    @pytest.mark.parametrize("compiled_model", compiled_model)
    def test_torch_erf(self, sizes, dtype, compiled_model):
        # Pick the static or dynamic shape depending on the compiled variant.
        shape = sizes.dynamic if compiled_model.dynamic else sizes.static
        cpu_input = torch.randn(shape, dtype=dtype)
        device_input = cpu_input.to(get_device())

        # Eager reference result on CPU.
        expected = model(cpu_input)

        # Reset dynamo state so each parametrized case compiles fresh.
        dynamo.reset()
        update_dynamo_config(compiled_model.dynamic)
        actual = compiled_model.model(device_input)

        assert torch.allclose(expected, actual.cpu(), atol=1e-04, equal_nan=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import pytest | ||
from common.utils import ( | ||
torch, | ||
dynamo, | ||
parse_args, | ||
compile_model, | ||
get_device, | ||
Size, | ||
update_dynamo_config, | ||
) | ||
|
||
|
||
class OpModule(torch.nn.Module):
    """Minimal wrapper module that applies ``aten.sin.default`` to its input.

    The aten op is called directly (rather than ``torch.sin``) so the graph
    captured by the compiler contains exactly this node.
    """

    def forward(self, a):
        return torch.ops.aten.sin.default(a)
|
||
|
||
# Eager-mode reference model used to produce the expected output.
model = OpModule()
# Backend/dynamic-shape flags supplied by the test harness command line.
args = parse_args()
# Compiled variants of the model; iterated by the pytest parametrization
# below.  # NOTE(review): presumably one entry per dynamic setting — confirm
# against common.utils.compile_model.
compiled_model = compile_model(model, args.backend, args.dynamic)
|
||
|
||
class TestSin():
    """Check ``aten.sin.default`` on the dicp backend against eager CPU.

    Mirrors the sibling cos/erf tests in this change set, including the
    comparison tolerance.
    """

    @pytest.mark.parametrize("dtype", [torch.float32])
    @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
    @pytest.mark.parametrize("compiled_model", compiled_model)
    def test_torch_sin(self, sizes, dtype, compiled_model):
        device = get_device()
        # Pick the static or dynamic shape depending on the compiled variant.
        size = sizes.dynamic if compiled_model.dynamic else sizes.static
        input1 = torch.randn(size, dtype=dtype)

        dicp_input1 = input1.to(device)

        # Eager reference result on CPU.
        output = model(input1)
        # Reset dynamo state so each parametrized case compiles fresh.
        dynamo.reset()
        update_dynamo_config(compiled_model.dynamic)
        dicp_output = compiled_model.model(dicp_input1)

        # atol=1e-04 matches the sibling cos/erf tests; the original call
        # relied on torch.allclose's default atol=1e-8, which is stricter
        # than the tolerance those tests grant the backend.
        assert torch.allclose(output, dicp_output.cpu(), atol=1e-04, equal_nan=True)
Oops, something went wrong.