forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Resize.cuh
70 lines (60 loc) · 2.02 KB
/
Resize.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#pragma once
#include <ATen/ATen.h>
#include <THC/THCTensor.hpp>
#include <c10/cuda/CUDAGuard.h>
namespace at { namespace native {
// These functions are called by native::resize_ as well as by the (legacy) THC resize.
// They live here rather than in THC/THCTensor.cpp because code in the at namespace
// is easier to benchmark than THC; gbenchmark could not be made to call functions
// defined in THTensor.cpp.
// Makes sure the storage behind `self` can hold
// `new_size + self->storage_offset()` elements, growing it with
// THCStorage_resize when it is currently too small. The storage is never
// shrunk, and a non-positive `new_size` is a no-op.
//
// Raises (AT_ERROR) when `new_size > 0` and THTensor_getStoragePtr(self)
// yields a null storage pointer.
static inline void maybe_resize_storage_cuda(TensorImpl* self, int64_t new_size) {
  // It does not make sense to resize a storage to hold 0 elements, and doing
  // so can break if storage_offset is positive but new_size is 0, so bail out
  // early in that case. (Resize.h carries the same reasoning.)
  if (new_size <= 0) {
    return;
  }

  auto* const storage_ptr = THTensor_getStoragePtr(self);
  if (!storage_ptr) {
    AT_ERROR("Tensor: invalid null storage");
  }

  // Number of elements needed so that [offset, offset + new_size) fits.
  const int64_t required_numel = new_size + self->storage_offset();
  if (required_numel > self->storage().numel()) {
    THCStorage_resize(globalContext().getTHCState(), storage_ptr, required_numel);
  }
}
// Resizes `self` in place to `size` and returns `self`.
//
// - With `stride`: sizes and strides are set verbatim. The storage extent
//   needed is 1 + sum_d (size[d] - 1) * stride[d], or 0 as soon as any
//   dimension is 0, so it can differ from numel().
// - Without `stride`: `self` is given contiguous strides for `size`, and the
//   extent needed is numel().
// The backing storage is then grown, never shrunk, by
// maybe_resize_storage_cuda. When the sizes (and, if given, the strides)
// already match, `self` is returned untouched.
//
// `device_guard`: when true, an OptionalCUDAGuard switches the current CUDA
// device to the storage's device until this function returns.
inline TensorImpl* resize_impl_cuda_(
    TensorImpl* self,
    IntArrayRef size,
    c10::optional<IntArrayRef> stride,
    bool device_guard = true) {
  // Fast path: nothing to do if the shape already matches. strides() is only
  // consulted when the sizes agree and a stride was requested.
  if (self->sizes() == size) {
    if (!stride || self->strides() == *stride) {
      return self;
    }
  }

  // NB: We don't need to hold the device guard when calling from TH.
  cuda::OptionalCUDAGuard guard;
  if (device_guard) {
    guard.set_index(self->storage().device().index());
  }

  int64_t storage_size;
  if (!stride) {
    self->set_sizes_contiguous(size);
    storage_size = self->numel();
  } else {
    const IntArrayRef new_stride = *stride;
    self->set_sizes_and_strides(size, new_stride);
    // NB: storage size can be different from numel; an empty dimension means
    // no storage is required at all.
    // FIXME: Don't rely on storage_size being negative because this may not
    // be true for some edge cases.
    storage_size = [&]() -> int64_t {
      int64_t extent = 1;
      for (size_t d = 0; d < size.size(); ++d) {
        if (size[d] == 0) {
          return 0;
        }
        extent += (size[d] - 1) * new_stride[d];
      }
      return extent;
    }();
  }

  maybe_resize_storage_cuda(self, storage_size);
  return self;
}
}}