forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
CUDAUnaryOps.cpp
69 lines (57 loc) · 2.07 KB
/
CUDAUnaryOps.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include <ATen/ATen.h>
#include <ATen/LegacyTHFunctionsCUDA.h>
#include <ATen/NamedTensorUtils.h>
namespace at { namespace native {
// In-place clamp entry point for CUDA tensors.
//
// Reuses _clamp_out_cuda with `self` acting as both the destination and the
// input, so the bound validation and name propagation performed there apply
// here as well. Returns whatever reference _clamp_out_cuda returns.
Tensor& _clamp__cuda(Tensor& self, optional<Scalar> min, optional<Scalar> max) {
  Tensor& clamped = _clamp_out_cuda(/*result=*/self, /*self=*/self, min, max);
  return clamped;
}
// Out-variant clamp entry point for CUDA tensors.
//
// Picks the legacy TH kernel based on which bounds were supplied:
//   - both `min` and `max` -> legacy::cuda::_th_clamp_out
//   - only `max`           -> legacy::cuda::_th_clamp_max_out
//   - only `min`           -> legacy::cuda::_th_clamp_min_out
// Raises via AT_ERROR when neither bound is given. After the kernel call,
// names are propagated from `self` to `result`, and `result` is returned.
Tensor& _clamp_out_cuda(
    Tensor& result,
    const Tensor& self,
    optional<Scalar> min,
    optional<Scalar> max) {
  const bool has_lower = static_cast<bool>(min);
  const bool has_upper = static_cast<bool>(max);

  // Reject the degenerate call up front, before touching `result`.
  if (!has_lower && !has_upper) {
    AT_ERROR("At least one of 'min' or 'max' must not be None");
  }

  if (has_lower && has_upper) {
    legacy::cuda::_th_clamp_out(result, self, *min, *max);
  } else if (has_upper) {
    legacy::cuda::_th_clamp_max_out(result, self, *max);
  } else {
    // Only the lower bound remains possible here.
    legacy::cuda::_th_clamp_min_out(result, self, *min);
  }

  at::namedinference::propagate_names(result, self);
  return result;
}
// In-place clamp-by-upper-bound entry point for CUDA tensors.
//
// Forwards straight to legacy::cuda::_th_clamp_max_out with `self` as both
// destination and input, and returns the reference that call yields. No
// separate name propagation is done here (destination and input are the same
// tensor).
Tensor& _clamp_max__cuda(Tensor& self, Scalar max) {
  Tensor& clamped =
      legacy::cuda::_th_clamp_max_out(/*result=*/self, /*self=*/self, max);
  return clamped;
}
// Out-variant clamp-by-upper-bound entry point for CUDA tensors.
//
// Runs legacy::cuda::_th_clamp_max_out into `result`, then propagates names
// from `self` to `result` and returns `result`.
Tensor& _clamp_max_out_cuda(Tensor& result, const Tensor& self, Scalar max) {
  namespace names = at::namedinference;

  legacy::cuda::_th_clamp_max_out(result, self, max);
  names::propagate_names(result, self);
  return result;
}
// In-place clamp-by-lower-bound entry point for CUDA tensors.
//
// Forwards straight to legacy::cuda::_th_clamp_min_out with `self` as both
// destination and input, and returns the reference that call yields. No
// separate name propagation is done here (destination and input are the same
// tensor).
Tensor& _clamp_min__cuda(Tensor& self, Scalar min) {
  Tensor& clamped =
      legacy::cuda::_th_clamp_min_out(/*result=*/self, /*self=*/self, min);
  return clamped;
}
// Out-variant clamp-by-lower-bound entry point for CUDA tensors.
//
// Runs legacy::cuda::_th_clamp_min_out into `result`, then propagates names
// from `self` to `result` and returns `result`.
Tensor& _clamp_min_out_cuda(Tensor& result, const Tensor& self, Scalar min) {
  namespace names = at::namedinference;

  legacy::cuda::_th_clamp_min_out(result, self, min);
  names::propagate_names(result, self);
  return result;
}
// Forwarding stubs for unary ops.
//
// IMPLEMENT_UNARY_OP_PREQUEL(name) defines two functions:
//   _<name>_out_cuda(result, self)  forwards to legacy::cuda::_th_<name>_out
//                                   and returns its result.
//   _<name>__cuda(self)             in-place form; routes through the out
//                                   form with `self` as both destination and
//                                   input.
// Neither stub does any work of its own beyond the forwarded call.
#define IMPLEMENT_UNARY_OP_PREQUEL(name)                           \
  Tensor& _##name##_out_cuda(Tensor& result, const Tensor& self) { \
    return legacy::cuda::_th_##name##_out(result, self);           \
  }                                                                \
  Tensor& _##name##__cuda(Tensor& self) {                          \
    return _##name##_out_cuda(self, self);                         \
  }
// Instantiate the forwarding stubs. Each line below expands to a pair of
// functions, _<op>__cuda (in-place) and _<op>_out_cuda (out variant), both of
// which forward to legacy::cuda::_th_<op>_out.
IMPLEMENT_UNARY_OP_PREQUEL(atan)
IMPLEMENT_UNARY_OP_PREQUEL(cos)
IMPLEMENT_UNARY_OP_PREQUEL(cosh)
IMPLEMENT_UNARY_OP_PREQUEL(erf)
IMPLEMENT_UNARY_OP_PREQUEL(erfc)
IMPLEMENT_UNARY_OP_PREQUEL(exp)
IMPLEMENT_UNARY_OP_PREQUEL(tan)
IMPLEMENT_UNARY_OP_PREQUEL(tanh)
}}