forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
zmath_std.h
157 lines (125 loc) · 4.03 KB
/
zmath_std.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#pragma once
// Complex number math operations that act as no-ops for other dtypes.
#include <complex>
#include <c10/util/math_compat.h>
#include<ATen/NumericUtils.h>
namespace at { namespace native {
namespace {
// ztype<T>::value_t names the real "component" type of a scalar type T.
// For non-complex types it is T itself. For std::complex<float> and
// std::complex<double> it is the complex type's component type
// (float and double respectively). Other complex types, if any, fall back to
// the primary template.
template <typename TYPE>
struct ztype {
  using value_t = TYPE;
};

// std::complex<float> -> float
template <>
struct ztype<std::complex<float>> {
  using value_t = std::complex<float>::value_type;
};

// std::complex<double> -> double
template <>
struct ztype<std::complex<double>> {
  using value_t = std::complex<double>::value_type;
};
// zabs: magnitude |z| of a std::complex value, computed with std::abs.
// Two flavours are specialized for each precision:
//   * zabs<std::complex<T>, T>   -> returns |z| as a plain real T;
//   * zabs<std::complex<T>>      -> keeps the complex return type and stores
//                                   |z| in the real part (imaginary part 0).
// These are explicit specializations; the primary template must already be
// declared before this point (it is not defined in this header).
template<>
inline float zabs <std::complex<float>, float> (std::complex<float> z) {
  const float magnitude = std::abs(z);
  return magnitude;
}
template<>
inline std::complex<float> zabs <std::complex<float>> (std::complex<float> z) {
  const float magnitude = std::abs(z);
  return std::complex<float>(magnitude, 0.0f);
}
template<>
inline double zabs <std::complex<double>, double> (std::complex<double> z) {
  const double magnitude = std::abs(z);
  return magnitude;
}
template<>
inline std::complex<double> zabs <std::complex<double>> (std::complex<double> z) {
  const double magnitude = std::abs(z);
  return std::complex<double>(magnitude, 0.0);
}
// angle_impl: phase angle of z, as returned by std::arg.
//   * angle_impl<std::complex<T>, T> -> the angle as a plain real T;
//   * angle_impl<std::complex<T>>    -> the angle in the real part of a
//                                       std::complex<T> (imaginary part 0).
// Explicit specializations of a primary template declared outside this file.
template<>
inline float angle_impl <std::complex<float>, float> (std::complex<float> z) {
  const float phase = std::arg(z);
  return phase;
}
template<>
inline std::complex<float> angle_impl <std::complex<float>> (std::complex<float> z) {
  const float phase = std::arg(z);
  return std::complex<float>(phase, 0.0f);
}
template<>
inline double angle_impl <std::complex<double>, double> (std::complex<double> z) {
  const double phase = std::arg(z);
  return phase;
}
template<>
inline std::complex<double> angle_impl <std::complex<double>> (std::complex<double> z) {
  const double phase = std::arg(z);
  return std::complex<double>(phase, 0.0);
}
// real_impl: real component Re(z).
//   * real_impl<std::complex<T>, T> -> Re(z) as a plain real T;
//   * real_impl<std::complex<T>>    -> (Re(z), 0) as a std::complex<T>.
// Bodies stay single-return expressions so they remain valid constexpr
// functions. Explicit specializations of a primary template declared
// outside this file.
template<>
constexpr float real_impl <std::complex<float>, float> (std::complex<float> z) {
  return z.real();
}
template<>
constexpr std::complex<float> real_impl <std::complex<float>> (std::complex<float> z) {
  return {z.real(), 0.0f};
}
template<>
constexpr double real_impl <std::complex<double>, double> (std::complex<double> z) {
  return z.real();
}
template<>
constexpr std::complex<double> real_impl <std::complex<double>> (std::complex<double> z) {
  return {z.real(), 0.0};
}
// imag_impl: imaginary component Im(z).
//   * imag_impl<std::complex<T>, T> -> Im(z) as a plain real T;
//   * imag_impl<std::complex<T>>    -> a std::complex<T> whose REAL part holds
//                                      Im(z) and whose imaginary part is 0.
//     This is (Im(z), 0), not (0, Im(z)).
// Single-return bodies keep these valid constexpr functions. Explicit
// specializations of a primary template declared outside this file.
template<>
constexpr float imag_impl <std::complex<float>, float> (std::complex<float> z) {
  return z.imag();
}
template<>
constexpr std::complex<float> imag_impl <std::complex<float>> (std::complex<float> z) {
  return {z.imag(), 0.0f};
}
template<>
constexpr double imag_impl <std::complex<double>, double> (std::complex<double> z) {
  return z.imag();
}
template<>
constexpr std::complex<double> imag_impl <std::complex<double>> (std::complex<double> z) {
  return {z.imag(), 0.0};
}
// conj_impl: complex conjugate. The real part is kept and the imaginary part
// is negated. Explicit specializations of a primary template declared outside
// this file.
template<>
inline std::complex<float> conj_impl <std::complex<float>> (std::complex<float> z) {
  const float negated_imag = -z.imag();
  return std::complex<float>(z.real(), negated_imag);
}
template<>
inline std::complex<double> conj_impl <std::complex<double>> (std::complex<double> z) {
  const double negated_imag = -z.imag();
  return std::complex<double>(z.real(), negated_imag);
}
// ceil_impl for complex inputs: applies std::ceil to the real and imaginary
// parts independently. The template argument is deduced from the parameter
// type; the primary template is declared outside this file.
template <>
inline std::complex<float> ceil_impl (std::complex<float> z) {
  const float re = std::ceil(z.real());
  const float im = std::ceil(z.imag());
  return {re, im};
}
template <>
inline std::complex<double> ceil_impl (std::complex<double> z) {
  const double re = std::ceil(z.real());
  const double im = std::ceil(z.imag());
  return {re, im};
}
// floor_impl for complex inputs: applies std::floor to the real and imaginary
// parts independently. The template argument is deduced from the parameter
// type; the primary template is declared outside this file.
template <>
inline std::complex<float> floor_impl (std::complex<float> z) {
  const float re = std::floor(z.real());
  const float im = std::floor(z.imag());
  return {re, im};
}
template <>
inline std::complex<double> floor_impl (std::complex<double> z) {
  const double re = std::floor(z.real());
  const double im = std::floor(z.imag());
  return {re, im};
}
// round_impl for complex inputs: rounds the real and imaginary parts
// independently with std::nearbyint. std::nearbyint follows the current
// floating-point rounding mode, which is round-half-to-even under the default
// FE_TONEAREST, and it does not raise FE_INEXACT. The primary template is
// declared outside this file.
template <>
inline std::complex<float> round_impl (std::complex<float> z) {
  const float re = std::nearbyint(z.real());
  const float im = std::nearbyint(z.imag());
  return {re, im};
}
template <>
inline std::complex<double> round_impl (std::complex<double> z) {
  const double re = std::nearbyint(z.real());
  const double im = std::nearbyint(z.imag());
  return {re, im};
}
// trunc_impl for complex inputs: applies std::trunc (round toward zero) to the
// real and imaginary parts independently. The primary template is declared
// outside this file.
template <>
inline std::complex<float> trunc_impl (std::complex<float> z) {
  const float re = std::trunc(z.real());
  const float im = std::trunc(z.imag());
  return {re, im};
}
template <>
inline std::complex<double> trunc_impl (std::complex<double> z) {
  const double re = std::trunc(z.real());
  const double im = std::trunc(z.imag());
  return {re, im};
}
} // end namespace
}} //end at::native