From caf437634984f0ef55b9493f08b1d6c8123ad2f0 Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Mon, 18 Sep 2023 14:25:29 -0700
Subject: [PATCH] [PyTorch] remove branch in isIntrusivePtr (#109273)

There is a code comment in ivalue.h that is intended to explain the motivation for this change fully; please request changes if it doesn't.

Differential Revision: [D49245910](https://our.internmc.facebook.com/intern/diff/D49245910/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D49245910/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109273
Approved by: https://github.com/ezyang
ghstack dependencies: #109272
---
 aten/src/ATen/core/ivalue.cpp |  9 ++++++++
 aten/src/ATen/core/ivalue.h   | 40 ++++++++++++++++++++++++++++++++---
 2 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp
index b0a9c367bdda09..374a787f50333d 100644
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -144,6 +144,14 @@ c10::TypePtr IValue::TagType::get(const IValue& v) {
   }
   // switch above is complete but this silences compiler warnings
   TORCH_INTERNAL_ASSERT(false, "unhandled case in IValue::type()");
+
+  // This static_assert has to go into some IValue member function; I
+  // chose this one. It's not in the class body because that's in
+  // ivalue.h, which is a very high-fanout header file and we want to
+  // minimize build time.
+  static_assert(
+      kNumTags <= 32,
+      "IValue::isIntrusivePtr needs to be updated because it assumes there are at most 32 tags");
 }
 
 void IValue::visit(const std::function<bool (const IValue &)>& visitor) const {
@@ -1227,4 +1235,5 @@ TORCH_API intrusive_ptr<ivalue::Future> collectAny(
   }
   return ctx->dstFuture;
 }
+
 } // namespace c10
diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index 07e1db8d6ba4b8..6a263c2df7f368 100644
--- a/aten/src/ATen/core/ivalue.h
+++ b/aten/src/ATen/core/ivalue.h
@@ -1142,6 +1142,10 @@ struct TORCH_API IValue final {
 #undef DEFINE_TAG
   };
 
+#define COUNT_TAG(x) 1 +
+  static constexpr auto kNumTags = TORCH_FORALL_TAGS(COUNT_TAG) 0;
+#undef COUNT_TAG
+
   template <
       class T,
       class NullType = c10::detail::intrusive_target_default_null_type<T>>
@@ -1191,7 +1195,10 @@ struct TORCH_API IValue final {
     tag = Tag::None;
   }
 
-  bool isIntrusivePtr() const {
+ private:
+  // This is the source of truth for isIntrusivePtr; edit results here
+  // as needed and isIntrusivePtr will pick them up.
+  static constexpr bool isIntrusivePtrConstexpr(Tag tag) {
     switch (tag) {
       case Tag::None:
         return false;
@@ -1248,11 +1255,38 @@ struct TORCH_API IValue final {
       case Tag::Enum:
         return true;
     }
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-        false, "unexpected tag ", static_cast<int>(tag));
     return false;
   }
 
+ public:
+  // Don't edit this just to add results for new tags; edit
+  // isIntrusivePtrConstexpr above.
+  bool isIntrusivePtr() const {
+    // Implementation NOTE: the switch in isIntrusivePtrConstexpr
+    // above is the previous production implementation of this
+    // function. We observed that, at least on x86_64, the generated
+    // instruction sequence was a similar bit vector test to what we
+    // have manually implemented below, except that there was an extra
+    // "bounds check" branch confirming, essentially, that `tag <
+    // kNumTags` and providing a consistent result in that case. We
+    // don't care about the result if tag is out of bounds, so we
+    // eliminate that comparison and branch by writing the bit test
+    // by hand. The mask uses an unsigned 1 so that bit 31 (allowed by
+    // the kNumTags <= 32 static_assert) never shifts into a sign bit.
+    static constexpr uint32_t kTruthTableBitVector =
+#define TRUTH_TABLE_ENTRY(tag) \
+  (uint32_t(isIntrusivePtrConstexpr(Tag::tag)) << uint32_t(Tag::tag)) |
+        TORCH_FORALL_TAGS(TRUTH_TABLE_ENTRY)
+#undef TRUTH_TABLE_ENTRY
+            0;
+
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+        uint32_t(tag) < kNumTags,
+        "unexpected tag ",
+        static_cast<int>(tag));
+    return kTruthTableBitVector & (uint32_t(1) << (uint32_t(tag) % 32));
+  }
+
   // Storage and Generator were treated specially when
   // is_intrusive_ptr was stored as explicit state. This getter
   // preserves the old behavior for use with WeakIValue for now.