Avoid COW materialize in nn.functional forward ops (2) (pytorch#121992)
Affected ops:
* dropout
* embedding
* embedding_bag
* multi_head_attention_forward
* grid_sample
* ctc_loss
* nll_loss
* pdist

Pull Request resolved: pytorch#121992
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#122437, pytorch#121991
kurtamohler authored and pytorchmergebot committed Mar 25, 2024
1 parent 55becf0 commit 1989271
Showing 13 changed files with 98 additions and 107 deletions.
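The change is mechanical across the files below: tensor inputs that are only read switch from data_ptr<T>() to const_data_ptr<T>() (with the matching const qualifiers on raw pointers and accessors), so reading a copy-on-write (COW) tensor in these forward kernels no longer counts as a write and no longer forces its storage to materialize. A minimal sketch of the pattern, assuming a hypothetical scale_rows helper that is not part of ATen:

// Sketch only: illustrates the read-only-pointer pattern applied in this commit.
// scale_rows is a hypothetical helper, not an ATen function.
#include <ATen/ATen.h>

void scale_rows(const at::Tensor& src, const at::Tensor& scale, at::Tensor& out) {
  // Inputs are only read, so take const pointers; const_data_ptr() does not
  // trigger COW materialization the way a mutable data_ptr() access would.
  const float* src_data = src.const_data_ptr<float>();
  const float* scale_data = scale.const_data_ptr<float>();
  // The output really is written, so a mutable pointer is appropriate here.
  float* out_data = out.data_ptr<float>();
  int64_t rows = src.size(0);
  int64_t cols = src.size(1);
  for (int64_t i = 0; i < rows; ++i) {
    for (int64_t j = 0; j < cols; ++j) {
      out_data[i * cols + j] = src_data[i * cols + j] * scale_data[i];
    }
  }
}

The diffs below apply exactly this substitution without changing any kernel logic.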
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Embedding.cpp
@@ -189,7 +189,7 @@ Tensor & embedding_renorm_cpu_(
auto num_indices = indices.numel();

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cpu_", [&]() {
- auto data_ptr = indices_contig.data_ptr<index_t>();
+ auto data_ptr = indices_contig.const_data_ptr<index_t>();
auto sorted_indices = std::vector<index_t>(data_ptr, data_ptr + num_indices);
std::sort(sorted_indices.begin(), sorted_indices.end());

62 changes: 31 additions & 31 deletions aten/src/ATen/native/EmbeddingBag.cpp
@@ -125,9 +125,9 @@ index_select_add(
index_t padding_idx,
_EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) {
TORCH_CHECK(select_indices.numel() == add_indices.numel());
- auto* add_indices_data = add_indices.data_ptr<index_t>();
- auto* select_indices_data = select_indices.data_ptr<index_t>();
- auto* src_data = src.data_ptr<data_t>();
+ auto* add_indices_data = add_indices.const_data_ptr<index_t>();
+ auto* select_indices_data = select_indices.const_data_ptr<index_t>();
+ auto* src_data = src.const_data_ptr<data_t>();
auto* output_data = output.data_ptr<data_t>();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
index_t* bag_size_data = nullptr;
@@ -208,14 +208,14 @@ index_select_add(
index_t padding_idx,
_EmbeddingBagKernelCache* fbgemm_kernel_cache) {
int64_t ddim = src.size(1);
- auto* select_indices_data = select_indices.data_ptr<index_t>();
+ auto* select_indices_data = select_indices.const_data_ptr<index_t>();
auto* output_data = output.data_ptr<data_t>();

if (is_fast_path_index_select(src, output, padding_idx)) {
auto src_contig = src.contiguous();
- auto* src_data = src_contig.data_ptr<data_t>();
+ auto* src_data = src_contig.const_data_ptr<data_t>();
int64_t output_size = offsets.numel() - 1;
- auto* offsets_data = offsets.data_ptr<index_t>();
+ auto* offsets_data = offsets.const_data_ptr<index_t>();
std::vector<index_t> offsets_include_last;

if (include_last_offset) {
@@ -316,8 +316,8 @@ index_select_add(
#endif
} else {
TORCH_CHECK(select_indices.numel() == add_indices.numel());
- auto* src_data = src.data_ptr<data_t>();
- auto* add_indices_data = add_indices.data_ptr<index_t>();
+ auto* src_data = src.const_data_ptr<data_t>();
+ auto* add_indices_data = add_indices.const_data_ptr<index_t>();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
index_t* bag_size_data = nullptr;
if (bag_size.defined()) {
@@ -388,14 +388,14 @@ index_select_add(const Tensor &select_indices,
index_t padding_idx,
_EmbeddingBagKernelCache* fbgemm_kernel_cache) {
int64_t ddim = src.size(1);
- auto* select_indices_data = select_indices.data_ptr<index_t>();
+ auto* select_indices_data = select_indices.const_data_ptr<index_t>();
auto* output_data = output.data_ptr<float>();

if (is_fast_path_index_select(src, output, padding_idx)) {
auto src_contig = src.contiguous();
- auto* src_data = src_contig.data_ptr<float>();
+ auto* src_data = src_contig.const_data_ptr<float>();
int64_t output_size = offsets.numel() - 1;
- auto* offsets_data = offsets.data_ptr<index_t>();
+ auto* offsets_data = offsets.const_data_ptr<index_t>();
std::vector<index_t> offsets_include_last;

if (include_last_offset) {
@@ -463,8 +463,8 @@ index_select_add(const Tensor &select_indices,
});
} else {
AT_ASSERT(select_indices.numel() == add_indices.numel());
- auto* src_data = src.data_ptr<float>();
- auto* add_indices_data = add_indices.data_ptr<index_t>();
+ auto* src_data = src.const_data_ptr<float>();
+ auto* add_indices_data = add_indices.const_data_ptr<index_t>();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
index_t* bag_size_data = nullptr;
if (bag_size.defined()) {
@@ -519,9 +519,9 @@ index_select_scale_add(
index_t padding_idx,
_EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) {
AT_ASSERT(select_indices.numel() == add_indices.numel());
- auto* add_indices_data = add_indices.data_ptr<index_t>();
- auto* select_indices_data = select_indices.data_ptr<index_t>();
- auto* src_data = src.data_ptr<data_t>();
+ auto* add_indices_data = add_indices.const_data_ptr<index_t>();
+ auto* select_indices_data = select_indices.const_data_ptr<index_t>();
+ auto* src_data = src.const_data_ptr<data_t>();
auto* output_data = output.data_ptr<data_t>();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
index_t* bag_size_data = nullptr;
@@ -536,7 +536,7 @@ index_select_scale_add(
auto output_stride0 = output.strides()[0];
auto output_stride1 = output.strides()[1];

- auto* scale_data = scale.data_ptr<data_t>();
+ auto* scale_data = scale.const_data_ptr<data_t>();
auto scale_stride = scale.strides()[0];

for (const auto i : c10::irange(numel)) {
@@ -579,15 +579,15 @@ index_select_scale_add(
index_t padding_idx,
_EmbeddingBagKernelCache* fbgemm_kernel_cache) {
int64_t ddim = src.size(1);
- auto* scale_data = scale.data_ptr<data_t>();
- auto* select_indices_data = select_indices.data_ptr<index_t>();
+ auto* scale_data = scale.const_data_ptr<data_t>();
+ auto* select_indices_data = select_indices.const_data_ptr<index_t>();
auto* output_data = output.data_ptr<data_t>();

if (is_fast_path_index_select_scale(src, scale, output, padding_idx)) {
auto src_contig = src.contiguous();
- auto* src_data = src_contig.data_ptr<data_t>();
+ auto* src_data = src_contig.const_data_ptr<data_t>();
int64_t output_size = offsets.numel() - 1;
- auto* offsets_data = offsets.data_ptr<index_t>();
+ auto* offsets_data = offsets.const_data_ptr<index_t>();
std::vector<index_t> offsets_include_last;

if (include_last_offset) {
@@ -705,8 +705,8 @@ index_select_scale_add(
#endif
} else {
AT_ASSERT(select_indices.numel() == add_indices.numel());
- auto* src_data = src.data_ptr<data_t>();
- auto* add_indices_data = add_indices.data_ptr<index_t>();
+ auto* src_data = src.const_data_ptr<data_t>();
+ auto* add_indices_data = add_indices.const_data_ptr<index_t>();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
index_t* bag_size_data = nullptr;
if (bag_size.defined()) {
@@ -770,15 +770,15 @@ index_select_scale_add(const Tensor &select_indices,
index_t padding_idx,
_EmbeddingBagKernelCache* fbgemm_kernel_cache) {
int64_t ddim = src.size(1);
- auto* scale_data = scale.data_ptr<float>();
- auto* select_indices_data = select_indices.data_ptr<index_t>();
+ auto* scale_data = scale.const_data_ptr<float>();
+ auto* select_indices_data = select_indices.const_data_ptr<index_t>();
auto* output_data = output.data_ptr<float>();

if (is_fast_path_index_select_scale(src, scale, output, padding_idx)) {
auto src_contig = src.contiguous();
- auto* src_data = src_contig.data_ptr<float>();
+ auto* src_data = src_contig.const_data_ptr<float>();
int64_t output_size = offsets.numel() - 1;
- auto* offsets_data = offsets.data_ptr<index_t>();
+ auto* offsets_data = offsets.const_data_ptr<index_t>();
std::vector<index_t> offsets_include_last;

if (include_last_offset) {
@@ -844,8 +844,8 @@ index_select_scale_add(const Tensor &select_indices,
});
} else {
AT_ASSERT(select_indices.numel() == add_indices.numel());
- auto* src_data = src.data_ptr<float>();
- auto* add_indices_data = add_indices.data_ptr<index_t>();
+ auto* src_data = src.const_data_ptr<float>();
+ auto* add_indices_data = add_indices.const_data_ptr<index_t>();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
index_t* bag_size_data = nullptr;
if (bag_size.defined()) {
@@ -1089,7 +1089,7 @@ void embedding_bag_cpu_max_out(
int64_t featureSize = weight.size(1);
int64_t vocab_size = weight.size(0);
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cpu_max_out", [&] {
- auto* indices_data = indices.data_ptr<index_t>();
+ auto* indices_data = indices.const_data_ptr<index_t>();
auto* offset2bag_data = offset2bag.data_ptr<index_t>();

index_t* max_indices_data = nullptr;
@@ -1099,7 +1099,7 @@ void embedding_bag_cpu_max_out(
max_indices_stride = max_indices->strides()[0];
}

- auto* weight_data = weight.data_ptr<scalar_t>();
+ auto* weight_data = weight.const_data_ptr<scalar_t>();
auto* output_data = output.data_ptr<scalar_t>();
auto* bag_size_data = bag_size.data_ptr<index_t>();
auto weight_stride0 = weight.strides()[0];
30 changes: 15 additions & 15 deletions aten/src/ATen/native/GridSampler.cpp
@@ -75,19 +75,19 @@ namespace {
int64_t out_sD = output.stride(2);
int64_t out_sH = output.stride(3);
int64_t out_sW = output.stride(4);
- scalar_t *inp_ptr = input.data_ptr<scalar_t>();
+ const scalar_t *inp_ptr = input.const_data_ptr<scalar_t>();
scalar_t *out_ptr = output.data_ptr<scalar_t>();
- scalar_t *grid_ptr = grid.data_ptr<scalar_t>();
+ const scalar_t *grid_ptr = grid.const_data_ptr<scalar_t>();
// loop over each output pixel
at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) {
for (const auto n : c10::irange(start, end)) {
- scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
- scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
+ const scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
+ const scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
for (const auto d : c10::irange(out_D)) {
for (const auto h : c10::irange(out_H)) {
for (const auto w : c10::irange(out_W)) {
// get the corresponding input x, y, z co-ordinates from grid
- scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW;
+ const scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW;
scalar_t ix = *grid_ptr_NDHW;
scalar_t iy = grid_ptr_NDHW[grid_sCoor];
scalar_t iz = grid_ptr_NDHW[2 * grid_sCoor];
@@ -144,7 +144,7 @@ namespace {

// calculate bilinear weighted pixel value and set output pixel
scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
- scalar_t *inp_ptr_NC = inp_ptr_N;
+ const scalar_t *inp_ptr_NC = inp_ptr_N;
for (int64_t c = 0; c < C; ++c, out_ptr_NCDHW += out_sC, inp_ptr_NC += inp_sC) {
// (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
// + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
@@ -183,7 +183,7 @@ namespace {

// assign nearest neighbour pixel value to output pixel
scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
- scalar_t *inp_ptr_NC = inp_ptr_N;
+ const scalar_t *inp_ptr_NC = inp_ptr_N;
for (int64_t c = 0; c < C; ++c, out_ptr_NCDHW += out_sC, inp_ptr_NC += inp_sC) {
if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) {
*out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW];
@@ -589,18 +589,18 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid,
int64_t out_sC = output.stride(1);
int64_t out_sH = output.stride(2);
int64_t out_sW = output.stride(3);
- scalar_t *inp_ptr = input.data_ptr<scalar_t>();
+ const scalar_t *inp_ptr = input.const_data_ptr<scalar_t>();
scalar_t *out_ptr = output.data_ptr<scalar_t>();
- scalar_t *grid_ptr = grid.data_ptr<scalar_t>();
+ const scalar_t *grid_ptr = grid.const_data_ptr<scalar_t>();
// loop over each output pixel
at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) {
for (const auto n : c10::irange(start, end)) {
- scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
- scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
+ const scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
+ const scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
for (const auto h : c10::irange(out_H)) {
for (const auto w : c10::irange(out_W)) {
// get the corresponding input x, y, z co-ordinates from grid
- scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
+ const scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
scalar_t x = *grid_ptr_NHW;
scalar_t y = grid_ptr_NHW[grid_sCoor];

@@ -630,7 +630,7 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid,
scalar_t se = (ix - ix_nw) * (iy - iy_nw);

// calculate bilinear weighted pixel value and set output pixel
- scalar_t *inp_ptr_NC = inp_ptr_N;
+ const scalar_t *inp_ptr_NC = inp_ptr_N;
scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
auto res = static_cast<scalar_t>(0);
@@ -654,7 +654,7 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid,

// assign nearest neighbour pixel value to output pixel
scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
- scalar_t *inp_ptr_NC = inp_ptr_N;
+ const scalar_t *inp_ptr_NC = inp_ptr_N;
for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) {
*out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW];
@@ -676,7 +676,7 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid,
const scalar_t tx = ix - ix_nw;
const scalar_t ty = iy - iy_nw;

- scalar_t *inp_ptr_NC = inp_ptr_N;
+ const scalar_t *inp_ptr_NC = inp_ptr_N;
scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
2 changes: 1 addition & 1 deletion aten/src/ATen/native/GridSampler.h
@@ -211,7 +211,7 @@ static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D,

template<typename scalar_t>
static inline scalar_t get_value_bounded(
- scalar_t* data,
+ const scalar_t* data,
scalar_t x,
scalar_t y,
int64_t W,
12 changes: 6 additions & 6 deletions aten/src/ATen/native/LossCTC.cpp
@@ -139,9 +139,9 @@ std::tuple<Tensor, Tensor> ctc_loss_cpu_template(const Tensor& log_probs, const

int64_t batch_size = log_probs.size(1);
auto lpp = log_probs.permute({1,0,2});
- auto log_probs_a_global = lpp.accessor<scalar_t, 3>();
+ auto log_probs_a_global = lpp.accessor<const scalar_t, 3>();
auto log_alpha_a_global = log_alpha.accessor<scalar_t, 3>();
- auto targets_data = targets.data_ptr<target_t>();
+ auto targets_data = targets.const_data_ptr<target_t>();
auto neg_log_likelihood_a = neg_log_likelihood.accessor<scalar_t, 1>();

// alpha calculation for the first row, the three equations for alpha_1 above eq (6)
@@ -423,8 +423,8 @@ std::tuple<Tensor, Tensor> ctc_loss_tensor(const Tensor& log_probs, const Tensor

Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
- IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
- IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
+ IntArrayRef il(ilc.const_data_ptr<int64_t>(), ilc.numel());
+ IntArrayRef tl(tlc.const_data_ptr<int64_t>(), tlc.numel());

return at::_ctc_loss(log_probs, targets, il, tl, BLANK, zero_infinity);
}
@@ -537,8 +537,8 @@ Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, const Tensor& in

Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
- IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
- IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
+ IntArrayRef il(ilc.const_data_ptr<int64_t>(), ilc.numel());
+ IntArrayRef tl(tlc.const_data_ptr<int64_t>(), tlc.numel());
return at::native::ctc_loss(log_probs, targets, il, tl, BLANK, reduction, zero_infinity);
}

2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
@@ -146,7 +146,7 @@ struct Dist {

template <typename F>
static void run_parallel_pdist(Tensor& result, const Tensor& self, const scalar_t p) {
- const scalar_t * const self_start = self.data_ptr<scalar_t>();
+ const scalar_t * const self_start = self.const_data_ptr<scalar_t>();
const scalar_t * const self_end = self_start + self.numel();
int64_t n = self.size(0);
int64_t m = self.size(1);
(The remaining changed files were not loaded in this view.)
