forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LerpKernel.cpp
65 lines (58 loc) · 2.27 KB
/
LerpKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/Lerp.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
namespace at {
namespace native {
namespace {
// Element-wise linear interpolation with a single scalar weight:
//   ret = self + weight * (end - self)
//
// Two algebraically equivalent forms are used, anchored at whichever endpoint
// the weight is closer to:
//   self + weight * (end - self)          when |weight| <  0.5
//   end  - (end - self) * (1 - weight)    otherwise
// For finite inputs this makes weight == 0 yield `self` exactly and
// weight == 1 yield `end` exactly.
//
// `self` and `end` must share a dtype; the output is written into `ret` with
// memory-overlap checking enabled.
static void lerp_kernel_scalar(
    Tensor& ret,
    const Tensor& self,
    const Tensor& end,
    Scalar weight) {
  TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(), " for `end` but got dtype ", end.dtype());
  auto iter = TensorIterator::binary_op(ret, self, end,
                                        /*check_mem_overlap=*/true);
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(ret.scalar_type(), "lerp_kernel_scalar", [&] {
    using value_t = typename c10::scalar_value_type<scalar_t>::type;
    const scalar_t weight_val = weight.to<scalar_t>();
    // The weight is identical for every element, so choose the formula and
    // compute (1 - weight) once here rather than re-evaluating them inside
    // the per-element lambda. Each element still performs exactly the same
    // arithmetic on the same values as before.
    const bool weight_is_small = zabs<scalar_t, value_t>(weight_val) < 0.5;
    const scalar_t one_minus_weight = scalar_t(1) - weight_val;
    at::native::cpu_kernel(
        iter,
        [weight_val, weight_is_small, one_minus_weight](
            scalar_t self_val, scalar_t end_val) -> scalar_t {
          return weight_is_small
              ? self_val + weight_val * (end_val - self_val)
              : end_val - (end_val - self_val) * one_minus_weight;
        });
  });
}
// Element-wise linear interpolation where the weight is itself a tensor:
//   ret[i] = self[i] + weights[i] * (end[i] - self[i])
//
// `self`, `end` and `weights` must all share one dtype, because the kernel
// reads every operand as the same scalar_t. The three inputs are combined
// element-wise through a TensorIterator writing into `ret`, with
// memory-overlap checking enabled.
static void lerp_kernel_tensor(
    Tensor& ret,
    const Tensor& self,
    const Tensor& end,
    const Tensor& weights) {
  // Reject mixed dtypes up front; `self` is the reference dtype.
  TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(), " for `end` but got dtype ", end.dtype());
  TORCH_CHECK(self.dtype() == weights.dtype(), "expected dtype ", self.dtype(), " for `weights` but got dtype ", weights.dtype());

  // One output followed by three inputs, in the order the lambda receives them.
  auto iter = TensorIteratorConfig()
                  .set_check_mem_overlap(true)
                  .add_output(ret)
                  .add_input(self)
                  .add_input(end)
                  .add_input(weights)
                  .build();

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(ret.scalar_type(), "lerp_kernel_tensor", [&] {
    using value_t = typename c10::scalar_value_type<scalar_t>::type;
    at::native::cpu_kernel(
        iter,
        [](scalar_t from, scalar_t to, scalar_t w) -> scalar_t {
          const scalar_t delta = to - from;
          if (zabs<scalar_t, value_t>(w) < 0.5) {
            // Weight near 0: anchor the interpolation at `from`.
            return from + w * delta;
          }
          // Weight near 1 (or beyond): anchor the interpolation at `to`.
          return to - delta * (scalar_t(1) - w);
        });
  });
}
} // anonymous namespace
// Bind the kernels defined above to their dispatch entries. The stubs
// `lerp_kernel_scalar_weight` / `lerp_kernel_tensor_weight` are presumably
// declared in ATen/native/Lerp.h (TODO confirm). These implementations use
// at::native::cpu_kernel, i.e. they are the CPU versions.
REGISTER_DISPATCH(lerp_kernel_scalar_weight, &lerp_kernel_scalar);
REGISTER_DISPATCH(lerp_kernel_tensor_weight, &lerp_kernel_tensor);
} // namespace native
} // namespace at