forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Constraints.cpp
90 lines (79 loc) · 2.56 KB
/
Constraints.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <limits>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <optional>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_functional_sym_constrain_range_native.h>
#include <ATen/ops/_make_dep_token_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/sym_constrain_range_native.h>
#include <ATen/ops/sym_constrain_range_for_size_native.h>
#include <ATen/ops/_functional_sym_constrain_range_for_size_native.h>
#endif
namespace at::native {
// Eagerly checks that the integer value of `size` lies in the closed
// interval [min, max].
//
// An absent `min` is treated as the smallest int64_t and an absent `max` as
// the largest int64_t. Two TORCH_CHECKs are performed, in this order:
//   1. the bounds themselves are consistent (max >= min);
//   2. the value falls inside the bounds.
void sym_constrain_range(
    const Scalar& size,
    std::optional<int64_t> min,
    std::optional<int64_t> max) {
  using Limits = std::numeric_limits<int64_t>;

  const int64_t lower = min.value_or(Limits::min());
  const int64_t upper = max.value_or(Limits::max());
  // The conversion is kept ahead of both checks so the evaluation order is
  // the same as it has always been.
  const int64_t value = size.toLong();

  TORCH_CHECK(
      lower <= upper,
      "Max must be greater than or equal to min. Got min=",
      lower,
      " max=",
      upper);

  TORCH_CHECK(
      !(value < lower || upper < value),
      "Invalid value range for ",
      value,
      " between [",
      lower,
      ", ",
      upper,
      "].");
}
// Functional counterpart of sym_constrain_range: runs the same range check on
// `size`, then hands back a clone of `dep_token` rather than the token itself.
Tensor _functional_sym_constrain_range(
    const Scalar& size,
    std::optional<int64_t> min,
    std::optional<int64_t> max,
    const Tensor& dep_token) {
  // Validate before cloning.
  sym_constrain_range(size, min, max);

  Tensor next_token = dep_token.clone();
  return next_token;
}
// Size-flavoured variant of sym_constrain_range.
//
// Differences from the plain range check:
//   - an explicitly supplied `max` must be greater than 2;
//   - an absent `min` defaults to 0 instead of the smallest int64_t.
// After that, the check is delegated to sym_constrain_range.
void sym_constrain_range_for_size(const Scalar& size, std::optional<int64_t> min, std::optional<int64_t> max) {
  // `*max` in the message is only reached when the condition is false, which
  // implies `max` holds a value.
  TORCH_CHECK(
      !max.has_value() || *max > 2,
      "Max value to constrain_range_for_size must be greater than 2. got: ",
      *max);

  const int64_t lower = min.value_or(0);
  sym_constrain_range(size, lower, max);
}
// Functional counterpart of sym_constrain_range_for_size: runs the same size
// check on `size`, then hands back a clone of `dep_token` rather than the
// token itself.
Tensor _functional_sym_constrain_range_for_size(
    const Scalar& size,
    std::optional<int64_t> min,
    std::optional<int64_t> max,
    const Tensor& dep_token) {
  // Validate before cloning.
  sym_constrain_range_for_size(size, min, max);

  Tensor next_token = dep_token.clone();
  return next_token;
}
// Builds a dependency-token tensor by forwarding every supplied option
// unchanged to at::empty, with an empty size list.
Tensor _make_dep_token_cpu(
    std::optional<ScalarType> dtype_opt,
    std::optional<Layout> layout_opt,
    std::optional<Device> device_opt,
    std::optional<bool> pin_memory_opt,
    std::optional<c10::MemoryFormat> memory_format_opt) {
  Tensor token = at::empty(
      {}, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
  return token;
}
} // namespace at::native