Skip to content

Commit

Permalink
[PyTorch] remove branch in isIntrusivePtr (pytorch#109273)
Browse files Browse the repository at this point in the history
There is a code comment in ivalue.h that is intended to explain the motivation for this change fully; please request changes if it doesn't.

Differential Revision: [D49245910](https://our.internmc.facebook.com/intern/diff/D49245910/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D49245910/)!
Pull Request resolved: pytorch#109273
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#109272
  • Loading branch information
swolchok authored and pytorchmergebot committed Sep 19, 2023
1 parent e29330d commit caf4376
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
9 changes: 9 additions & 0 deletions aten/src/ATen/core/ivalue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ c10::TypePtr IValue::TagType<c10::Type>::get(const IValue& v) {
}
// switch above is complete but this silences compiler warnings
TORCH_INTERNAL_ASSERT(false, "unhandled case in IValue::type()");

// This static_assert has to go into some IValue member function; I
// chose this one. It's not in the class body because that's in
// ivalue.h, which is a very high-fanout header file and we want to
// minimize build time.
static_assert(
kNumTags <= 32,
"IValue::isIntrusivePtr needs to be updated because it assumes there are at most 32 tags");
}

void IValue::visit(const std::function<bool (const IValue &)>& visitor) const {
Expand Down Expand Up @@ -1227,4 +1235,5 @@ TORCH_API intrusive_ptr<ivalue::Future> collectAny(
}
return ctx->dstFuture;
}

} // namespace c10
40 changes: 37 additions & 3 deletions aten/src/ATen/core/ivalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,10 @@ struct TORCH_API IValue final {
#undef DEFINE_TAG
};

#define COUNT_TAG(x) 1 +
static constexpr auto kNumTags = TORCH_FORALL_TAGS(COUNT_TAG) 0;
#undef COUNT_TAG

template <
class T,
class NullType = c10::detail::intrusive_target_default_null_type<T>>
Expand Down Expand Up @@ -1191,7 +1195,10 @@ struct TORCH_API IValue final {
tag = Tag::None;
}

bool isIntrusivePtr() const {
private:
// This is the source of truth for isIntrusivePtr; edit results here
// as needed and isIntrusivePtr will pick them up.
static constexpr bool isIntrusivePtrConstexpr(Tag tag) {
switch (tag) {
case Tag::None:
return false;
Expand Down Expand Up @@ -1248,11 +1255,38 @@ struct TORCH_API IValue final {
case Tag::Enum:
return true;
}
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
false, "unexpected tag ", static_cast<int>(tag));
return false;
}

public:
// Don't edit this just to add results for new tags; edit
// isIntrusivePtrConstexpr above.
bool isIntrusivePtr() const {
// Implementation NOTE: the switch in isIntrusivePtrConstexpr
// above is the previous production implementation of this
// function. We observed that, at least on x86_64, the generated
// instruction sequence was a similar bit vector test to what we
// have manually implemented below, except that there was an extra
// "bounds check" branch confirming, essentially, that `tag <
// kNumTags` and providing a consistent result in that case. We
// don't care about the result if tag is out of bounds, so we'd
// like to eliminate that comparison and branch; manually
// implementing this function as a bit test is the simplest way I
// could find to accomplish that elimination.
static constexpr uint32_t kTruthTableBitVector =
#define TRUTH_TABLE_ENTRY(tag) \
(uint32_t(isIntrusivePtrConstexpr(Tag::tag)) << uint32_t(Tag::tag)) |
TORCH_FORALL_TAGS(TRUTH_TABLE_ENTRY)
#undef TRUTH_TABLE_ENTRY
0;

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
uint32_t(tag) >= 0 && uint32_t(tag) < kNumTags,
"unexpected tag ",
static_cast<int>(tag));
return kTruthTableBitVector & (1 << (uint32_t(tag) % 32));
}

// Storage and Generator were treated specially when
// is_intrusive_ptr was stored as explicit state. This getter
// preserves the old behavior for use with WeakIValue for now.
Expand Down

0 comments on commit caf4376

Please sign in to comment.