TensorCompare.cu
#include <ATen/NativeFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>

namespace {

// Elementwise ternary select: ret[i] = condition[i] ? self[i] : other[i].
// The condition tensor may be either Byte (legacy uint8 masks) or Bool,
// so branch on its scalar type before launching the kernel.
template <typename scalar_t>
void where_cuda(
    at::Tensor& ret,
    const at::Tensor& condition,
    const at::Tensor& self,
    const at::Tensor& other) {
  if (condition.scalar_type() == at::ScalarType::Byte) {
    // Yes this name is repetitive, but the CPU version is called
    // CPU_tensor_apply4 and we don't have a CPU namespace or directory.
    at::cuda::CUDA_tensor_apply4<scalar_t, uint8_t, scalar_t, scalar_t>(
        ret,
        condition,
        self,
        other,
        [] __device__(
            scalar_t& ret_val,
            const uint8_t& cond_val,
            const scalar_t& self_val,
            const scalar_t& other_val) {
          ret_val = cond_val ? self_val : other_val;
        });
  } else {
    at::cuda::CUDA_tensor_apply4<scalar_t, bool, scalar_t, scalar_t>(
        ret,
        condition,
        self,
        other,
        [] __device__(
            scalar_t& ret_val,
            const bool& cond_val,
            const scalar_t& self_val,
            const scalar_t& other_val) {
          ret_val = cond_val ? self_val : other_val;
        });
  }
}

} // namespace

namespace at { namespace native {

// CUDA implementation of at::_s_where. The inputs are assumed to have
// already been broadcast to the same shape by the caller.
Tensor _s_where_cuda(
    const Tensor& condition,
    const Tensor& self,
    const Tensor& other) {
  Tensor ret = at::empty(self.sizes(), self.options());
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::Bool, ret.scalar_type(), "where_cuda", [&] {
        where_cuda<scalar_t>(ret, condition, self, other);
      });
  return ret;
}

}} // namespace at::native
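For context, a minimal usage sketch (not part of this file): the kernel above is reached through the public where API, which broadcasts its arguments before dispatching to _s_where_cuda on CUDA tensors. The snippet below assumes a libtorch build with CUDA available; the file name and build flags are illustrative.

// usage_sketch.cpp -- hypothetical standalone example, not part of this repo.
// Build against libtorch with CUDA; exact compiler flags depend on your install.
#include <torch/torch.h>
#include <iostream>

int main() {
  // Bool condition mask on the GPU; a Byte (uint8) mask would take the
  // other branch of where_cuda above.
  auto cond = torch::tensor({1, 0, 1}, torch::device(torch::kCUDA)).to(torch::kBool);
  auto a = torch::ones({3}, torch::device(torch::kCUDA));
  auto b = torch::zeros({3}, torch::device(torch::kCUDA));

  // torch::where broadcasts its inputs, then on CUDA tensors ends up in
  // _s_where_cuda, which selects a[i] where cond[i] is true, else b[i].
  auto out = torch::where(cond, a, b);
  std::cout << out << std::endl; // expected: [1, 0, 1]
  return 0;
}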