Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Nov 15, 2024
1 parent 0843fc4 commit 3c5c68e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
38 changes: 24 additions & 14 deletions csrc/multidevice/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,18 @@ int64_t numDeviceDims(const TensorView* tv) {
[](IterDomain* id) { return id->isDeviceDim(); });
}

namespace {
// Returns the number of entries in `a` that also appear in `b`. Every entry
// of `a` is checked on its own, so a pointer that occurs several times in `a`
// contributes once per occurrence. Membership in `b` is determined by pointer
// equality using a linear search.
int countIntersection(const std::vector<Val*>& a, const std::vector<Val*>& b) {
  const auto is_in_b = [&b](Val* candidate) {
    return std::find(b.begin(), b.end(), candidate) != b.end();
  };
  return static_cast<int>(std::count_if(a.begin(), a.end(), is_in_b));
}
} // namespace

bool haveDifferentShardings(
const TensorView* producer,
const TensorView* consumer,
Expand All @@ -147,9 +159,9 @@ bool haveDifferentShardings(
// Create a map between producer's and consumer's IterDomains. We iterate
// over the producer's IterDomains and compare each one's sharding type with
// that of the corresponding consumer IterDomain.
std::vector<IterDomain*> mapped_p_ids;
std::vector<Val*> mapped_p_ids;
mapped_p_ids.reserve(producer->getLogicalDomain().size());
std::vector<IterDomain*> mapped_c_ids;
std::vector<Val*> mapped_c_ids;
mapped_c_ids.reserve(consumer->getMaybeRootDomain().size());
const std::unordered_map<IterDomain*, IterDomain*>& p2c =
PairwiseLogicalDomainMap(producer, consumer).mapProducerToConsumer();
Expand All @@ -164,31 +176,29 @@ bool haveDifferentShardings(
}

std::unordered_map<ParallelType, IterDomain*> p_parallel_type_to_id;
auto mapped_p_ids_set = DependencyCheck::getAllDependentVals(
{mapped_p_ids.begin(), mapped_p_ids.end()});
for (IterDomain* p_id : producer->getLoopDomain()) {
if (mapped_p_ids_set.count(p_id)) {
if (const ParallelType parallel_type = p_id->getParallelType();
isParallelTypeDeviceDim(parallel_type)) {
if (const ParallelType parallel_type = p_id->getParallelType();
isParallelTypeDeviceDim(parallel_type)) {
auto dependencies = IterVisitor::getInputsTo({p_id}, mapped_p_ids);
if (countIntersection(dependencies, mapped_p_ids) > 0) {
NVF_ERROR(p_parallel_type_to_id.count(parallel_type) == 0);
p_parallel_type_to_id[parallel_type] = p_id;
}
}
}
auto mapped_c_ids_set = DependencyCheck::getAllDependentVals(
{mapped_c_ids.begin(), mapped_c_ids.end()});
std::unordered_map<ParallelType, IterDomain*> c_parallel_type_to_id;
for (IterDomain* c_id : consumer->getLoopDomain()) {
if (mapped_c_ids_set.count(c_id)) {
if (const ParallelType parallel_type = c_id->getParallelType();
isParallelTypeDeviceDim(parallel_type)) {
if (const ParallelType parallel_type = c_id->getParallelType();
isParallelTypeDeviceDim(parallel_type)) {
auto dependencies = IterVisitor::getInputsTo({c_id}, mapped_c_ids);
if (countIntersection(dependencies, mapped_c_ids) > 0) {
NVF_ERROR(c_parallel_type_to_id.count(parallel_type) == 0);
c_parallel_type_to_id[parallel_type] = c_id;
}
}
}

const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT);
const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::PERMISSIVE);
for (const auto parallel_type : kParallelTypeDIDs) {
if (p_parallel_type_to_id.count(parallel_type) !=
c_parallel_type_to_id.count(parallel_type)) {
Expand All @@ -211,7 +221,7 @@ bool isResharding(const Expr* expr) {
}

IdModel id_model({const_cast<Expr*>(expr)}, {}, false, false);
id_model.buildExactGraph();
id_model.buildPermissiveGraph();
// We don't use getTvsWithDifferentSharding because it creates a computeAtMap,
// which is too costly
for (auto* input : ir_utils::filterByType<TensorView>(expr->inputs())) {
Expand Down
4 changes: 2 additions & 2 deletions csrc/preseg_passes/insert_reshardings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ bool shouldReshardAfter(Expr* expr) {

void insertReshardingsBefore(Fusion* fusion) {
IdModel id_model(fusion, false, false, true);
id_model.buildExactGraph();
id_model.buildPermissiveGraph();
// Remove this after we refactor this as a pre-segmenter pass.
FusionGuard fg(fusion);
for (auto expr : fusion->exprs()) {
Expand Down Expand Up @@ -80,7 +80,7 @@ void insertReshardingsBefore(Fusion* fusion) {

void insertReshardingsAfter(Fusion* fusion) {
IdModel id_model(fusion, false, false, true);
id_model.buildExactGraph();
id_model.buildPermissiveGraph();

// Remove this after we refactor this as a pre-segmenter pass.
FusionGuard fg(fusion);
Expand Down

0 comments on commit 3c5c68e

Please sign in to comment.