diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp index d1f06d80e1d..d89f0a3f3a4 100644 --- a/tests/cpp/test_multidevice_lower_communication.cpp +++ b/tests/cpp/test_multidevice_lower_communication.cpp @@ -7,6 +7,7 @@ // clang-format on #include +#include #include #include @@ -16,15 +17,23 @@ namespace nvfuser { +using testing::Each; +using testing::IsTrue; +using testing::Pointer; +using testing::Property; + namespace { void assertIsCompiledToHostIrContainer( const FusionExecutorCache& executor_cache) { FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); - EXPECT_TRUE(runtime->executors().size() == 1); - for (const auto& ea : runtime->executors()) { - EXPECT_TRUE(ea->isA()) - << "failed to compile to a HostIrContainer with Communications"; - } + EXPECT_EQ(runtime->executors().size(), 1); + EXPECT_THAT( + runtime->executors(), + Each(Pointer(Property( + "is a HostIrExecutor", + &ExecutorAbstract::isA, + IsTrue())))) + << "failed to compile to a HostIrContainer with Communications"; } } // namespace