forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
NestedIntSymNodeImpl.cpp
80 lines (69 loc) · 2.77 KB
/
NestedIntSymNodeImpl.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <ATen/core/NestedIntSymNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/util/Exception.h>

#include <cstdint>
#include <limits>
#include <optional>
namespace c10 {

namespace {

// Returns true iff `a * b` would overflow int64_t. Signed overflow is
// undefined behavior in C++, so the check must happen before multiplying.
// Written with plain integer division (no compiler builtins) for portability.
bool mul_overflows_i64(int64_t a, int64_t b) {
  if (a == 0 || b == 0) {
    return false;
  }
  constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
  constexpr int64_t kMin = std::numeric_limits<int64_t>::min();
  if (a > 0) {
    // positive * positive can exceed kMax; positive * negative can go below
    // kMin.
    return b > 0 ? a > kMax / b : b < kMin / a;
  }
  // `a` is negative here: negative * positive can go below kMin;
  // negative * negative can exceed kMax.
  return b > 0 ? a < kMin / b : a < kMax / b;
}

// Equality of the nested int `lhs` against an arbitrary node `rhs`.
// Returns true only when `rhs->nested_int()` holds a value equal to
// `lhs->nested_int()` AND both coefficients (nested_int_coeff()) match.
// An `rhs` whose nested_int() is empty compares unequal.
//
// `op` is accepted for signature parity with _ge() but is not used.
bool _eq(
    [[maybe_unused]] const char* op,
    c10::SymNodeImpl* lhs,
    c10::SymNodeImpl* rhs) {
  TORCH_INTERNAL_ASSERT(lhs->is_nested_int());
  const std::optional<int64_t> rhs_id = rhs->nested_int();
  return rhs_id.has_value() && lhs->nested_int() == *rhs_id &&
      lhs->nested_int_coeff() == rhs->nested_int_coeff();
}

// Decides `lhs >= rhs` when at least one side is a nested int.
//
// - Both sides nested ints with the same id: compare their coefficients.
// - Nested int vs. constant: every nested int is treated as being >= 2, so
//   `nested >= c` is true for c <= 2 and `c >= nested` is false for c < 2.
// - Any other combination is undecidable and raises (TORCH_CHECK); having no
//   nested int on either side is an internal error (TORCH_INTERNAL_ASSERT).
//
// `op` is the user-facing comparison name (e.g. "gt"), used only in error
// messages.
//
// NOTE(review): the constant branches ignore nested_int_coeff(). The ">= 2"
// reasoning only holds for coefficients >= 1, but mul() accepts any constant
// factor. Confirm whether non-positive coefficients can reach this code.
bool _ge(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
  if (const auto lhs_id = lhs->nested_int()) {
    if (const auto rhs_id = rhs->nested_int()) {
      if (*lhs_id == *rhs_id) {
        return lhs->nested_int_coeff() >= rhs->nested_int_coeff();
      }
      TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
    }
    // Bind once: avoids a second virtual call and an unchecked-optional
    // dereference.
    const std::optional<int64_t> rhs_const = rhs->constant_int();
    if (rhs_const.has_value() && *rhs_const <= 2) {
      return true;
    }
    TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
  } else if (rhs->nested_int()) {
    const std::optional<int64_t> lhs_const = lhs->constant_int();
    if (lhs_const.has_value() && *lhs_const < 2) {
      return false;
    }
    TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
  }
  TORCH_INTERNAL_ASSERT(false, "expect at least one nested int");
}

} // namespace

// Comparison methods. Each one wraps the decided boolean in a constant bool
// node. Relations that cannot be decided raise inside _eq/_ge instead.
// gt/lt/le are all expressed through _ge:
//   a >  b  <=>  !(b >= a)
//   a <  b  <=>  !(a >= b)
//   a <= b  <=>   (b >= a)

c10::SymNode NestedIntSymNodeImpl::eq(const c10::SymNode& other) {
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
      _eq("eq", this, other.get())));
}

c10::SymNode NestedIntSymNodeImpl::ne(const c10::SymNode& other) {
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
      !_eq("ne", this, other.get())));
}

c10::SymNode NestedIntSymNodeImpl::ge(const c10::SymNode& other) {
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
      _ge("ge", this, other.get())));
}

c10::SymNode NestedIntSymNodeImpl::gt(const c10::SymNode& other) {
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
      !_ge("gt", other.get(), this)));
}

c10::SymNode NestedIntSymNodeImpl::lt(const c10::SymNode& other) {
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
      !_ge("lt", this, other.get())));
}

c10::SymNode NestedIntSymNodeImpl::le(const c10::SymNode& other) {
  return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
      _ge("le", other.get(), this)));
}

// Returns a new nested int with the same id (val_) and its coefficient
// scaled by `other`, which must be a constant int. Multiplying by another
// nested int, by a non-constant node, or by a factor that would overflow the
// int64_t coefficient raises via TORCH_CHECK.
c10::SymNode NestedIntSymNodeImpl::mul(const c10::SymNode& other) {
  TORCH_CHECK(
      !other->nested_int(), "nested int cannot be multiplied by nested int");
  const std::optional<int64_t> c = other->constant_int();
  TORCH_CHECK(
      c.has_value(), "nested int can only be multiplied by a constant int");
  // coeff_ * *c on int64_t is undefined behavior on overflow; reject it.
  TORCH_CHECK(
      !mul_overflows_i64(coeff_, *c),
      "nested int mul: coefficient overflow (",
      coeff_,
      " * ",
      *c,
      ")");
  return SymNode(c10::make_intrusive<NestedIntSymNodeImpl>(val_, coeff_ * *c));
}

// Returns an independent node with the same id and coefficient.
c10::SymNode NestedIntSymNodeImpl::clone() {
  return SymNode(c10::make_intrusive<NestedIntSymNodeImpl>(val_, coeff_));
}

} // namespace c10