forked from chuanqi129/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Cross.cpp
82 lines (66 loc) · 2.76 KB
/
Cross.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Cross.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorMeta.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/Resize.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cross_native.h>
#include <ATen/ops/linalg_cross.h>
#include <ATen/ops/linalg_cross_native.h>
#endif
namespace at {
namespace meta {
// Meta function for linalg_cross: validates the arguments and declares
// output 0. No cross product is computed here.
//
// Checks, in order:
//   1. `input` and `other` have the same number of dimensions.
//   2. Both tensors have length 3 along `dim`.
// Output 0 is then declared with the broadcast of the two input shapes and
// with `input`'s options.
TORCH_META_FUNC(linalg_cross)
(const Tensor & input, const Tensor & other, int64_t dim) {
  const auto input_ndim = input.dim();
  const auto other_ndim = other.dim();
  // Requiring equal rank rejects calls such as
  // linalg.cross(torch.randn(2, 3), torch.randn(5, 2, 3), dim=2)
  TORCH_CHECK(input_ndim == other_ndim, "linalg.cross: inputs must have the same number of dimensions.");

  // Length of each operand along the requested dimension.
  const auto input_len = input.size(dim);
  const auto other_len = other.size(dim);
  TORCH_CHECK(!(input_len != 3 || other_len != 3), "linalg.cross: inputs dimension ", dim, " must have length 3. Got ", input_len, " and ", other_len);

  // Broadcast the batch dimensions of the two operands. The non-batch
  // dimension was just verified to agree, so broadcasting the full shapes
  // gives the same result.
  auto broadcast_shape = infer_size(input.sizes(), other.sizes());
  set_output_raw_strided(0, broadcast_shape, {}, input.options());
}
}
namespace native {
// Defines the `cross_stub` dispatch stub. It is invoked from the
// linalg_cross_out implementation later in this file, with the input's device
// type as its first argument. The per-device kernels are presumably
// registered in other translation units — TODO confirm.
DEFINE_DISPATCH(cross_stub);
// Resolves the dimension that `cross` / `cross_out` forward to linalg_cross.
//
// dimension: the caller-supplied dimension, if any; returned unchanged when set.
// sizes:     the sizes to search when no dimension was supplied.
//
// Returns `*dimension` when given, otherwise the index of the first entry of
// `sizes` equal to 3. Note that this "first size-3 dimension" default may be
// surprising to callers. Raises via TORCH_CHECK when no dimension was given
// and no entry of `sizes` equals 3.
static int64_t _default_cross_dim(const c10::optional<int64_t> &dimension, SymIntArrayRef sizes) {
  // An explicitly requested dimension is passed through untouched.
  if (dimension) {
    return dimension.value();
  }

  // Otherwise scan from the front and take the first dimension of size 3.
  const size_t ndim = sizes.size();
  for (size_t idx = 0; idx < ndim; ++idx) {
    if (sizes[idx] == 3) {
      return static_cast<int64_t>(idx);
    }
  }

  TORCH_CHECK(false, "no dimension of size 3 in input");
}
// Functional `cross`: resolves the (possibly omitted) dimension from
// `input`'s sizes via _default_cross_dim, then delegates to at::linalg_cross.
// Only `input` is consulted when picking the default dimension; `other` is not.
Tensor cross(const Tensor & input, const Tensor & other, const c10::optional<int64_t> dimension) {
  return at::linalg_cross(
      input,
      other,
      _default_cross_dim(dimension, input.sym_sizes()));
}
// Out-variant of `cross`: resolves the (possibly omitted) dimension from
// `input`'s sizes via _default_cross_dim, then delegates to
// at::linalg_cross_out with `out` as the destination. Returns the reference
// produced by at::linalg_cross_out.
Tensor & cross_out(const Tensor & input, const Tensor & other, const c10::optional<int64_t> dimension, Tensor & out) {
  const int64_t resolved_dim = _default_cross_dim(dimension, input.sym_sizes());
  Tensor & result = at::linalg_cross_out(out, input, other, resolved_dim);
  return result;
}
// Implementation function for linalg_cross_out.
//
// Wraps `dim` against `input`'s number of dimensions, expands both operands
// to the shape of `out`, and invokes `cross_stub` for the input's device type
// with `out` as the destination.
TORCH_IMPL_FUNC(linalg_cross_out)
(const Tensor & input, const Tensor & other, int64_t dim, const Tensor & out) {
  // Normalize `dim` before handing it to the stub.
  const int64_t wrapped_dim = maybe_wrap_dim(dim, input.dim());

  // Both operands are expanded to the output's shape so the stub receives
  // same-shaped tensors.
  const auto target_shape = out.sizes();
  const Tensor expanded_input = input.expand(target_shape);
  const Tensor expanded_other = other.expand(target_shape);

  cross_stub(input.device().type(), out, expanded_input, expanded_other, wrapped_dim);
}
}} // namespace at::native