// NOTE(review): the text below is a git unified diff against
// csrc/multidevice/utils.cpp and csrc/preseg_passes/insert_reshardings.cpp
// whose line breaks have been collapsed onto a few long lines. It is not a
// standalone translation unit. The hunk text is left untouched because
// context/removed lines of a patch have to match the target file; only this
// leading note is added.
//
// What the hunks show:
//  * utils.cpp adds a file-local helper, countIntersection(a, b), inside an
//    anonymous namespace. It walks `a` and counts the elements for which
//    std::find over `b` succeeds (one linear scan of `b` per element of `a`).
//  * haveDifferentShardings no longer builds mapped_p_ids_set /
//    mapped_c_ids_set with DependencyCheck::getAllDependentVals. Instead, for
//    every IterDomain in the producer's (resp. consumer's) loop domain whose
//    parallel type satisfies isParallelTypeDeviceDim, it calls
//    IterVisitor::getInputsTo({id}, mapped_ids) and records the id in the
//    parallel-type map only when countIntersection(dependencies, mapped_ids)
//    is greater than zero. The NVF_ERROR guard against a parallel type being
//    recorded twice is kept as context in both loops.
//  * The graph fetched from id_model in haveDifferentShardings changes from
//    IdMappingMode::EXACT to IdMappingMode::PERMISSIVE, and the IdModel set-up
//    in isResharding, insertReshardingsBefore and insertReshardingsAfter
//    switches from buildExactGraph() to buildPermissiveGraph().
//
// NOTE(review): the local is still named `exact_graph` although the added line
// binds it to the PERMISSIVE graph. Its later uses are outside the visible
// hunks, so a rename has to be done in the target file, not here.
// NOTE(review): both visible call sites only test `countIntersection(...) > 0`,
// so neither of them needs the full count.
// NOTE(review): template arguments look like they were lost in extraction
// (e.g. `std::vector& a`, `std::unordered_map p_parallel_type_to_id`,
// `const_cast(expr)`); the `-`/`+` pair for mapped_p_ids / mapped_c_ids reads
// identically for the same reason -- presumably the element type changed.
// TODO confirm against the original patch before applying.
diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 8164a38163c..8e3cb4cf26c 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -125,6 +125,18 @@ int64_t numDeviceDims(const TensorView* tv) { [](IterDomain* id) { return id->isDeviceDim(); }); } +namespace { +int countIntersection(const std::vector& a, const std::vector& b) { + int count = 0; + for (auto id : a) { + if (std::find(b.begin(), b.end(), id) != b.end()) { + count++; + } + } + return count; +} +} // namespace + bool haveDifferentShardings( const TensorView* producer, const TensorView* consumer, @@ -147,9 +159,9 @@ bool haveDifferentShardings( // Create a map between producer's and consumer's IterDomains. We iterate // over producer's iterdomain and compare sharding type with consumer's // iterdomain - std::vector mapped_p_ids; + std::vector mapped_p_ids; mapped_p_ids.reserve(producer->getLogicalDomain().size()); - std::vector mapped_c_ids; + std::vector mapped_c_ids; mapped_c_ids.reserve(consumer->getMaybeRootDomain().size()); const std::unordered_map& p2c = PairwiseLogicalDomainMap(producer, consumer).mapProducerToConsumer(); @@ -164,31 +176,29 @@ bool haveDifferentShardings( } std::unordered_map p_parallel_type_to_id; - auto mapped_p_ids_set = DependencyCheck::getAllDependentVals( - {mapped_p_ids.begin(), mapped_p_ids.end()}); for (IterDomain* p_id : producer->getLoopDomain()) { - if (mapped_p_ids_set.count(p_id)) { - if (const ParallelType parallel_type = p_id->getParallelType(); - isParallelTypeDeviceDim(parallel_type)) { + if (const ParallelType parallel_type = p_id->getParallelType(); + isParallelTypeDeviceDim(parallel_type)) { + auto dependencies = IterVisitor::getInputsTo({p_id}, mapped_p_ids); + if (countIntersection(dependencies, mapped_p_ids) > 0) { NVF_ERROR(p_parallel_type_to_id.count(parallel_type) == 0); p_parallel_type_to_id[parallel_type] = p_id; } } } - auto mapped_c_ids_set = DependencyCheck::getAllDependentVals( - {mapped_c_ids.begin(), 
mapped_c_ids.end()}); std::unordered_map c_parallel_type_to_id; for (IterDomain* c_id : consumer->getLoopDomain()) { - if (mapped_c_ids_set.count(c_id)) { - if (const ParallelType parallel_type = c_id->getParallelType(); - isParallelTypeDeviceDim(parallel_type)) { + if (const ParallelType parallel_type = c_id->getParallelType(); + isParallelTypeDeviceDim(parallel_type)) { + auto dependencies = IterVisitor::getInputsTo({c_id}, mapped_c_ids); + if (countIntersection(dependencies, mapped_c_ids) > 0) { NVF_ERROR(c_parallel_type_to_id.count(parallel_type) == 0); c_parallel_type_to_id[parallel_type] = c_id; } } } - const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::PERMISSIVE); for (const auto parallel_type : kParallelTypeDIDs) { if (p_parallel_type_to_id.count(parallel_type) != c_parallel_type_to_id.count(parallel_type)) { @@ -211,7 +221,7 @@ bool isResharding(const Expr* expr) { } IdModel id_model({const_cast(expr)}, {}, false, false); - id_model.buildExactGraph(); + id_model.buildPermissiveGraph(); // We don't use getTvsWithDifferentSharding because it creates a computeAtMap, // which is too costly for (auto* input : ir_utils::filterByType(expr->inputs())) { diff --git a/csrc/preseg_passes/insert_reshardings.cpp b/csrc/preseg_passes/insert_reshardings.cpp index ec3ee800e50..ca586facb4c 100644 --- a/csrc/preseg_passes/insert_reshardings.cpp +++ b/csrc/preseg_passes/insert_reshardings.cpp @@ -30,7 +30,7 @@ bool shouldReshardAfter(Expr* expr) { void insertReshardingsBefore(Fusion* fusion) { IdModel id_model(fusion, false, false, true); - id_model.buildExactGraph(); + id_model.buildPermissiveGraph(); // Remove this after we refactor this as a pre-segmenter pass. 
FusionGuard fg(fusion); for (auto expr : fusion->exprs()) { @@ -80,7 +80,7 @@ void insertReshardingsBefore(Fusion* fusion) { void insertReshardingsAfter(Fusion* fusion) { IdModel id_model(fusion, false, false, true); - id_model.buildExactGraph(); + id_model.buildPermissiveGraph(); // Remove this after we refactor this as a pre-segmenter pass. FusionGuard fg(fusion);