forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Correlation.cpp
153 lines (131 loc) · 4.75 KB
/
Correlation.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorSubclassLikeUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/complex.h>
#include <ATen/ops/corrcoef_native.h>
#include <ATen/ops/cov.h>
#include <ATen/ops/cov_native.h>
#include <ATen/ops/imag.h>
#include <ATen/ops/mm.h>
#include <ATen/ops/real.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/sqrt.h>
#include <ATen/ops/true_divide.h>
#endif
namespace at::native {
// Estimates the covariance matrix of the variables in `self`.
//
// The input is read as a (variables, observations) matrix: each row of a 2-D
// input is one variable and each column one observation, while a 0-D or 1-D
// input is treated as a single variable.
//
// Arguments:
//   self       - input with at most two dimensions; bool dtype is rejected.
//   correction - subtracted from the (weighted) observation count to obtain
//                the degrees of freedom; without weights, 1 gives the
//                unbiased sample estimate and 0 the population estimate.
//   fweights   - optional per-observation frequency weights: at most 1-D,
//                integral dtype, non-negative, one entry per observation.
//   aweights   - optional per-observation importance weights: at most 1-D,
//                floating-point dtype, non-negative, one entry per observation.
//
// Returns the covariance matrix with every size-1 dimension squeezed out, so a
// single variable produces a 0-D tensor.
//
// Fails a TORCH_CHECK on invalid input/weights or when the combined weights
// sum to zero. If the degrees of freedom come out <= 0, a warning is emitted
// and the normalization factor is forced to zero before dividing.
Tensor cov(
    const Tensor& self,
    int64_t correction,
    const std::optional<Tensor>& fweights,
    const std::optional<Tensor>& aweights) {
  // Once the input is viewed as 2-D, observations run along this dimension.
  constexpr int64_t kObsDim = 1;

  const int64_t input_dims = self.ndimension();
  TORCH_CHECK(
      input_dims <= 2,
      "cov(): expected input to have two or fewer dimensions but got an input with ",
      input_dims,
      " dimensions");
  TORCH_CHECK(
      self.scalar_type() != kBool,
      "cov(): bool dtype is not supported for input");

  // Promote 0-D / 1-D input to a single-row (1, N) matrix.
  const Tensor samples = input_dims < 2 ? self.view({1, -1}) : self;
  const int64_t n_obs = samples.size(kObsDim);

  // Per-observation weight: fweights, aweights, their product, or undefined
  // when neither is supplied.
  Tensor weights;

  if (fweights.has_value()) {
    const Tensor& freq = fweights.value();
    TORCH_CHECK(
        freq.ndimension() <= 1,
        "cov(): expected fweights to have one or fewer dimensions but got fweights with ",
        freq.ndimension(),
        " dimensions");
    TORCH_CHECK(
        at::isIntegralType(freq.scalar_type(), false),
        "cov(): expected fweights to have integral dtype but got fweights with ",
        freq.scalar_type(),
        " dtype");
    TORCH_CHECK(
        freq.numel() == n_obs,
        "cov(): expected fweights to have the same numel as there are observations in the input but got ",
        freq.numel(),
        " != ",
        n_obs);
    // The min() reduction is skipped when there are no observations
    // (presumably because it cannot run on an empty tensor).
    TORCH_CHECK(
        n_obs == 0 || at::is_scalar_tensor_true(freq.min().ge(0)),
        "cov(): fweights cannot be negative");
    weights = freq;
  }

  if (aweights.has_value()) {
    const Tensor& importance = aweights.value();
    TORCH_CHECK(
        importance.ndimension() <= 1,
        "cov(): expected aweights to have one or fewer dimensions but got aweights with ",
        importance.ndimension(),
        " dimensions");
    TORCH_CHECK(
        at::isFloatingType(importance.scalar_type()),
        "cov(): expected aweights to have floating point dtype but got aweights with ",
        importance.scalar_type(),
        " dtype");
    TORCH_CHECK(
        importance.numel() == n_obs,
        "cov(): expected aweights to have the same numel as there are observations in the input but got ",
        importance.numel(),
        " != ",
        n_obs);
    TORCH_CHECK(
        n_obs == 0 || at::is_scalar_tensor_true(importance.min().ge(0)),
        "cov(): aweights cannot be negative");
    if (weights.defined()) {
      weights = weights * importance;
    } else {
      weights = importance;
    }
  }

  const bool weighted = weights.defined();

  // Total weight; without weights this is simply the observation count.
  const Tensor total = weighted
      ? weights.sum()
      : at::scalar_tensor(n_obs, samples.options().dtype(kLong));

  if (weighted) {
    TORCH_CHECK(
        at::is_scalar_tensor_true(total.ne(0)),
        "cov(): weights sum to zero, can't be normalized");
  }

  // Weighted mean of every variable across its observations.
  Tensor mean;
  if (weighted) {
    mean = (samples * weights).sum(kObsDim) / total;
  } else {
    mean = samples.sum(kObsDim) / total;
  }

  // Degrees of freedom. When aweights are present (and a correction is
  // requested) the correction is scaled by sum(weights * aweights) / total.
  const bool scale_correction =
      weighted && aweights.has_value() && correction != 0;
  Tensor dof = scale_correction
      ? total - correction * (weights * aweights.value()).sum() / total
      : total - correction;

  if (at::is_scalar_tensor_true(dof.le(0))) {
    TORCH_WARN("cov(): degrees of freedom is <= 0. Correction should be strictly less than the number of observations.");
    dof.zero_();
  }

  // Center each variable, then form centered @ (centered * weights)^H.
  const Tensor centered = samples - mean.unsqueeze(1);
  const Tensor rhs = weighted ? centered * weights : centered;
  const Tensor unnormalized = at::mm(centered, rhs.t().conj());
  return at::true_divide(unnormalized, dof).squeeze();
}
// Pearson correlation coefficients of the variables in `self`, using the same
// layout as cov(): rows are variables, columns are observations, and 0-D / 1-D
// input is a single variable.
//
// Starting from at::cov(self), each entry cov[i][j] is divided by
// sqrt(cov[i][i]) and then by sqrt(cov[j][j]). For complex input, the square
// root is taken of the real part of the diagonal. The result is clipped to
// [-1, 1], with the real and imaginary parts clipped separately for complex
// results.
//
// A single variable yields a 0-D result equal to var / var: 1 for a finite
// non-zero variance, NaN when the variance is 0, inf or NaN.
Tensor corrcoef(const Tensor& self) {
  const int64_t rank = self.ndimension();
  TORCH_CHECK(
      rank <= 2,
      "corrcoef(): expected input to have two or fewer dimensions but got an input with ",
      rank,
      " dimensions");

  const Tensor covariance = at::cov(self);

  // Scalar covariance: self-division gives 1, or NaN for {NaN, inf, 0}.
  if (covariance.ndimension() == 0) {
    return covariance / covariance;
  }

  const Tensor variances = covariance.diagonal();
  const Tensor stddev =
      at::sqrt(variances.is_complex() ? at::real(variances) : variances);

  // Scale rows by 1/stddev[i], then columns by 1/stddev[j]. These are two
  // separate divisions, not one division by the outer product, so the
  // floating-point result is unchanged.
  const Tensor corr =
      covariance / stddev.view({-1, 1}) / stddev.view({1, -1});

  // Floating-point rounding can leave values slightly outside [-1, 1];
  // clip them back into range, just as NumPy does.
  if (!corr.is_complex()) {
    return corr.clip(-1, 1);
  }
  return at::complex(at::real(corr).clip(-1, 1), at::imag(corr).clip(-1, 1));
}
} // namespace at::native