make conv2d out at right memory-format (DeepLink-org#502)
wyz5864 authored and brianlcy123 committed Dec 21, 2023
1 parent 4d19db6 commit 72e11f8
Showing 2 changed files with 20 additions and 3 deletions.
6 changes: 3 additions & 3 deletions dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -515,7 +515,7 @@
int64_t out_height = (height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1;
int64_t out_width = (width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1;
c10::SmallVector<int64_t, 4> output_size = {batch_size, out_channel, out_height, out_width};
- at::Tensor out = at::empty(output_size, input.options());
+ at::Tensor out = at::empty(output_size, input.options().memory_format(input.suggest_memory_format()));
interface: diopiConvolution2d(&context, out, input, weight, bias, stride, padding, dilation, groups)
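
For context, a minimal Python sketch (illustration only, not part of the diff) of what the change above fixes: a plain at::empty/torch.empty allocates NCHW-contiguous strides regardless of the input's layout, so a channels-last input used to get an NCHW output. Passing the input's suggested memory format keeps the two layouts matched; the sketch approximates suggest_memory_format() with an explicit channels_last argument.

import torch

x = torch.rand(2, 3, 5, 5).to(memory_format=torch.channels_last)

# Default allocation ignores the input's layout: strides come out NCHW-contiguous.
out_default = torch.empty(2, 4, 3, 3)
print(out_default.is_contiguous(memory_format=torch.channels_last))  # False

# Allocating with the input's memory format keeps the output channels-last.
out_nhwc = torch.empty(2, 4, 3, 3, memory_format=torch.channels_last)
print(out_nhwc.is_contiguous(memory_format=torch.channels_last))  # True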

- schema: "convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)"
@@ -527,10 +527,10 @@
at::Tensor grad_bias;
std::vector<int64_t> bias_sizes;
if (output_mask[0]) {
-   grad_input = at::empty(input.sizes(), input.options());
+   grad_input = at::empty_like(input);
}
if (output_mask[1]) {
-   grad_weight = at::empty(weight.sizes(), weight.options().dtype(at::kFloat));
+   grad_weight = at::empty(weight.sizes(), weight.options().dtype(at::kFloat).memory_format(weight.suggest_memory_format()));
}
if (output_mask[2]) {
    bias_sizes.push_back(grad_output.size(1));
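
A brief note on the backward change, with a hedged Python sketch (illustration only, not from the diff): empty_like defaults to preserving the source tensor's layout, so allocating grad_input with at::empty_like keeps the input's (possibly channels-last) strides, while a plain empty defaults to contiguous strides and therefore needs the memory format stated explicitly, as grad_weight does above.

import torch

inp = torch.rand(2, 3, 5, 5).to(memory_format=torch.channels_last)

# empty_like preserves the source tensor's layout by default (torch.preserve_format).
grad_input = torch.empty_like(inp)
print(grad_input.is_contiguous(memory_format=torch.channels_last))  # True

w = torch.rand(4, 3, 3, 3).to(memory_format=torch.channels_last)
# A plain empty would come out NCHW-contiguous, so the layout is restated explicitly.
grad_weight = torch.empty(w.shape, dtype=torch.float32, memory_format=torch.channels_last)
print(grad_weight.is_contiguous(memory_format=torch.channels_last))  # True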
17 changes: 17 additions & 0 deletions dipu/tests/python/unittests/test_conv2d.py
@@ -39,6 +39,23 @@ def test_conv_2d(self):
        )
        # print("conv2d output compare successfully")

+    def test_conv2d_nhwc(self):
+        device = torch.device("dipu")
+
+        m = nn.Conv2d(2, 3, 3).to(device=device, memory_format=torch.channels_last)
+        self.assertTrue(m.weight.is_contiguous(memory_format=torch.channels_last))
+
+        x = torch.rand(2, 2, 5, 5).to(device=device, memory_format=torch.channels_last)
+        x.requires_grad_()
+        self.assertTrue(x.is_contiguous(memory_format=torch.channels_last))
+
+        y = m(x)
+        self.assertTrue(y.is_contiguous(memory_format=torch.channels_last))
+
+        y.backward(torch.rand_like(y))
+        self.assertTrue(x.grad.is_contiguous(memory_format=torch.channels_last))
+        self.assertTrue(m.weight.grad.is_contiguous(memory_format=torch.channels_last))
+

if __name__ == "__main__":
    run_tests()
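
As an aside (standard PyTorch semantics, not specific to this commit): channels_last only changes a tensor's strides, not its logical NCHW shape, which is exactly what the is_contiguous(memory_format=torch.channels_last) assertions in the test above check.

import torch

t = torch.rand(2, 3, 4, 5)                      # logical NCHW shape
print(t.stride())                               # (60, 20, 5, 1): NCHW (contiguous) layout
nhwc = t.to(memory_format=torch.channels_last)
print(nhwc.shape == t.shape)                    # True: same logical shape
print(nhwc.stride())                            # (60, 1, 15, 3): NHWC (channels-last) layout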
