From db6d4d772f6c9ea4a04a558172b5685f725b128f Mon Sep 17 00:00:00 2001 From: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> Date: Mon, 9 Oct 2023 17:51:12 +0800 Subject: [PATCH] Standard naming Part 2 (#57939) * split SymbolicDimMgr ShapeComputationIRAnalysis --- .../shape/transforms/shape_optimization.cc | 169 ++++ .../shape/transforms/shape_optimization.h | 52 ++ .../transforms/shape_optimization_pass.cc | 1 + .../shape/utils/shape_optimization_utils.cc | 606 +++++++++++++++ .../shape/utils/shape_optimization_utils.h | 94 +++ paddle/pir/dialect/shape/utils/shape_utils.cc | 734 ------------------ paddle/pir/dialect/shape/utils/shape_utils.h | 129 +-- .../pir/shape_dialect/constraint_pass_test.cc | 1 + 8 files changed, 924 insertions(+), 862 deletions(-) create mode 100644 paddle/pir/dialect/shape/transforms/shape_optimization.cc create mode 100644 paddle/pir/dialect/shape/transforms/shape_optimization.h diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization.cc b/paddle/pir/dialect/shape/transforms/shape_optimization.cc new file mode 100644 index 0000000000000..959d098675b29 --- /dev/null +++ b/paddle/pir/dialect/shape/transforms/shape_optimization.cc @@ -0,0 +1,169 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pir/dialect/shape/transforms/shape_optimization.h" +#include "paddle/pir/dialect/shape/utils/shape_utils.h" + +namespace pir { + +ShapeComputationIRAnalysis::ShapeComputationIRAnalysis(ModuleOp m, + SymbolicDimMgr& mgr) + : m_(m), mgr_(mgr) {} + +bool ShapeComputationIRAnalysis::Run() { + // Make sure only run once. 
+ if (initialized_) return false; + initialized_ = true; + auto buildShapeFunc = + std::bind(&ShapeComputationIRAnalysis::BuildShapeOnOperation, + this, + std::placeholders::_1); + if (!RunOnRegion(&(m_->region(0)), buildShapeFunc)) return false; + auto applyOpConstraintFunc = + std::bind(&ShapeComputationIRAnalysis::ApplyOpConstraint, + this, + std::placeholders::_1); + if (!RunOnRegion(&(m_->region(0)), applyOpConstraintFunc)) return false; + return true; +} + +bool ShapeComputationIRAnalysis::RunOnRegion(Region* region, func fn) { + for (Block* block : *region) { + if (!RunOnBlock(block, fn)) return false; + } + return true; +} + +bool ShapeComputationIRAnalysis::RunOnBlock(Block* block, func fn) { + // TODO(liujinnan): mapping block arguments + + std::vector op_list; + for (Operation* op : *block) op_list.push_back(op); + for (Operation* op : op_list) { + if (!RunOnOperation(op, fn)) return false; + } + return true; +} + +bool ShapeComputationIRAnalysis::RunOnOperation(Operation* op, func fn) { + for (size_t i = 0; i < op->num_regions(); ++i) { + if (!RunOnRegion(&(op->region(i)), fn)) return false; + } + return fn(op); +} + +bool ShapeComputationIRAnalysis::BuildShapeOnOperation(Operation* op) { + if (op->isa()) return true; + if (op->isa()) { + Value value = op->operand_source(0); + std::vector symbols; + if (op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) { + auto attrs = + op->attribute(SymbolicDim::GetSymbolicDimAttrName()) + .AsVector(); + for (Attribute attr : attrs) { + auto sym = mgr_.symbolTable().Lookup( + attr.dyn_cast().AsString()); + assert(sym); + SymbolicDim root = mgr_.GetRootSymbolicDim(sym); + symbols.push_back(root); + } + } else { + symbols = mgr_.CreateSymbolicDimsForRankedValue(value); + std::vector attrs; + for (SymbolicDim sym : symbols) { + Attribute rootSymbol = + StrAttribute::get(m_->ir_context(), sym.GetSymName()); + attrs.push_back(rootSymbol); + } + op->set_attribute(SymbolicDim::GetSymbolicDimAttrName(), + ArrayAttribute::get(m_->ir_context(), attrs)); + } + rankedTensor2SymDims_[value] = std::move(symbols); + return true; + } + for (size_t i = 0; i < op->num_results(); ++i) { + if (!BuildShapeOnValue(op->result(i))) return false; + } + return true; +} + +bool ShapeComputationIRAnalysis::BuildShapeOnValue(Value value) { + Type type = value.type(); + if (IsIntOrIndex(type)) { + SymbolicDim sym = mgr_.NewSymbolicDim(); + value2SymDim_[value] = sym; + } else if (IsCandidateShapeTensorType(type)) { + auto shapedTy = type.dyn_cast(); + std::vector symbols; + for (size_t i = 0, d = shapedTy.GetShape()[0]; i < d; ++i) + symbols.push_back(mgr_.NewSymbolicDim()); + shapeTensor2SymDims_[value] = std::move(symbols); + } + return true; +} + +bool ShapeComputationIRAnalysis::ApplyOpConstraint(Operation* op) { + IR_ENFORCE(ApplyIndexOpConstraint(op), + "Fail to apply constraint for index op"); + IR_ENFORCE(ApplyTieShapeOpConstraint(op), + "Fail to apply constraint for tie_shape op"); + + // TODO(zhangbo63): add more constraints + return true; +} + +bool ShapeComputationIRAnalysis::ApplyIndexOpConstraint(Operation* op) { + if (op->num_results() == 0) return true; + + Type type = op->result(0).type(); + if (!IsIntOrIndex(type)) return true; + + if (auto dimOp = op->dyn_cast()) { + int64_t dimIndex = dimOp.index() + .dyn_cast() + .owner() + ->attribute("value") + .data(); + value2SymDim_[dimOp.out()].UpdateKnownNonNegative(true); + if (!mgr_.MapSymbolicDimEqual( + value2SymDim_[dimOp.out()], + rankedTensor2SymDims_[dimOp.source()][dimIndex])) { + return false; 
+ } + + } else if (auto constOp = op->dyn_cast()) { + int64_t val = constOp.value().dyn_cast().data(); + if (!mgr_.MapSymbolicDimEqual(value2SymDim_[op->result(0)], + mgr_.NewConstantSymbolicDim(val))) { + return false; + } + } + // TODO(zhangbo63): add support for reifyInferShape. (e.g. mul/add) + return true; +} + +bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) { + if (auto tieShape = op->dyn_cast()) { + auto& value = rankedTensor2SymDims_[op->operand_source(0)]; + for (size_t idx = 0; idx < tieShape.dims().size(); ++idx) { + if (!mgr_.MapSymbolicDimEqual(value2SymDim_[tieShape.dims()[idx]], + value[idx])) + return false; + mgr_.GetRootSymbolicDim(value[idx]).UpdateKnownNonNegative(true); + } + } + return true; +} +} // namespace pir diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization.h b/paddle/pir/dialect/shape/transforms/shape_optimization.h new file mode 100644 index 0000000000000..ba711f288a770 --- /dev/null +++ b/paddle/pir/dialect/shape/transforms/shape_optimization.h @@ -0,0 +1,52 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/dialect/shape/utils/shape_optimization_utils.h" +#include "paddle/pir/dialect/shape/utils/symbol_table.h" + +namespace pir { +class ShapeComputationIRAnalysis { + public: + using func = std::function; + explicit ShapeComputationIRAnalysis(ModuleOp m, + SymbolicDimMgr& mgr); // NOLINT + bool Run(); + + private: + bool RunOnRegion(Region* region, func fn); + bool RunOnBlock(Block* block, func fn); + bool RunOnOperation(Operation* op, func fn); + + bool BuildShapeOnOperation(Operation* op); + bool BuildShapeOnValue(Value value); + + bool ApplyOpConstraint(Operation* op); + bool ApplyIndexOpConstraint(Operation* op); + bool ApplyTieShapeOpConstraint(Operation* op); + + bool initialized_ = false; + ModuleOp m_; + SymbolicDimMgr& mgr_; + + std::unordered_map value2SymDim_; + + // shape tensor is the 1D ranked tensor with int/index dtype. 
+ std::unordered_map> shapeTensor2SymDims_; + + std::unordered_map> rankedTensor2SymDims_; +}; + +} // namespace pir diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc index 6bbb918ebc1f1..f9316f3682aa3 100644 --- a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc +++ b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc @@ -18,6 +18,7 @@ #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/program.h" +#include "paddle/pir/dialect/shape/transforms/shape_optimization.h" #include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" diff --git a/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc b/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc index 35776be4f5325..07f7cf4129a4d 100644 --- a/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc @@ -13,3 +13,609 @@ // limitations under the License. #include "paddle/pir/dialect/shape/utils/shape_optimization_utils.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/dialect/shape/utils/symbol_table.h" + +namespace pir { + +bool CompareSymbolicDimNames(const std::string& lhs, const std::string& rhs) { + // S -> Symbol : unknown dimension size at compile time + // C -> Constant : constant dimension size at compile time + if (lhs.size() < 1 || (lhs[0] != 'S' && lhs[0] != 'C')) return lhs < rhs; + if (rhs.size() < 1 || (rhs[0] != 'S' && rhs[0] != 'C')) return lhs < rhs; + int64_t lhs_idx = 0, rhs_idx = 0; + try { + lhs_idx = stol(lhs.substr(1)); + rhs_idx = stol(rhs.substr(1)); + } catch (const std::exception& e) { + IR_THROW("Invalid symbolic name"); + } + return (lhs[0] < rhs[0]) || (lhs[0] == rhs[0] && lhs_idx < rhs_idx); +} + +// Gives a consistent order of a list op SymbolicDimProducts +bool CompareSymbolicDimProduct(SymbolicDimProduct& lhs, // NOLINT + SymbolicDimProduct& rhs) { // NOLINT + if (lhs.symbols.size() < rhs.symbols.size()) return true; + if (lhs.symbols.size() == rhs.symbols.size()) { + for (size_t idx = 0; idx < lhs.symbols.size(); ++idx) { + const std::string lhs_name = lhs.symbols[idx].GetSymName(); + const std::string rhs_name = rhs.symbols[idx].GetSymName(); + if (CompareSymbolicDimNames(lhs_name, rhs_name)) return true; + if (lhs_name != rhs_name) return false; + } + } + return false; +} + +SymbolicDimMgr::SymbolicDimMgr(ModuleOp m) : m_(m) { + for (auto op : *(m.block())) { + if (op->isa()) { + symbol_table_ = SymbolTable(op); + return; + } + } + Builder builder = Builder(m_.ir_context(), m_.block(), m_.block()->begin()); + dialect::FuncOp func = builder.Build(); + symbol_table_ = SymbolTable(func); +} + +bool SymbolicDimMgr::Load() { + auto func_op = symbol_table_.getOp()->dyn_cast(); + assert(func_op); + for (auto op : *(func_op.block())) { + symbol_table_.insert(op); + if (SymbolicDim sym_dim_op = op->dyn_cast()) { + symbol_dim_union_set_[sym_dim_op] = sym_dim_op; + symbol_name_set_.insert(sym_dim_op.GetSymName()); + } + } + return LoadShapeConstraintGraph(); +} + +bool SymbolicDimMgr::LoadShapeConstraintGraph() { + // TODO(liujinnan): add more constraint function. currently, only support + // tie_product_equal. 
+ auto constraint_vec = + symbol_table_.Lookup("tie_product_equal"); + + if (!constraint_vec.size()) return true; + + auto build_sym_product = [&](std::vector range, + SymbolicDimProduct& product) { + for (Value v : range) { + auto defining_op = v.dyn_cast().owner(); + if (auto constOp = defining_op->dyn_cast()) { + product.factor *= constOp.value().dyn_cast().data(); + continue; + } else if (auto dimOp = defining_op->dyn_cast()) { + auto sym = symbol_table_.Lookup(dimOp.getName()); + if (!sym) return false; + product.symbols.push_back(sym); + continue; + } + return false; + } + return true; + }; + + for (auto op : constraint_vec) { + SymbolicDimProduct lhs, rhs; + if (!build_sym_product(op.lhs(), lhs) || + !build_sym_product(op.rhs(), rhs) || + !MapSymbolicDimProductEqual(lhs, rhs)) + return false; + } + return true; +} + +bool SymbolicDimMgr::MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, + const SymbolicDimProduct& rhs) { + SymbolicDimProduct new_lhs, new_rhs; + std::tie(new_lhs, new_rhs) = SimplifySymbolicDimProductPair(lhs, rhs); + + // Return true for identity case. + if (new_lhs == new_rhs) return true; + + if (new_lhs.factor == new_rhs.factor && new_lhs.symbols.size() == 1 && + new_rhs.symbols.size() == 1) { + return MapSymbolicDimEqual(new_lhs.symbols[0], new_rhs.symbols[0]); + } else if (new_lhs.symbols.size() == 0 && new_rhs.symbols.size() == 1 && + new_rhs.factor == 1) { + return MapSymbolicDimEqual(NewConstantSymbolicDim(new_lhs.factor), + new_rhs.symbols[0]); + } else if (new_rhs.symbols.size() == 0 && new_lhs.symbols.size() == 1 && + new_lhs.factor == 1) { + return MapSymbolicDimEqual(NewConstantSymbolicDim(new_rhs.factor), + new_lhs.symbols[0]); + } + + product_equality_map_[new_lhs][new_rhs] = + product_equality_map_[new_rhs][new_lhs] = true; + + product_equality_map_updated_ = false; + return true; +} + +SymbolicDimProduct SymbolicDimMgr::SimplifySymbolicDimProduct( + const SymbolicDimProduct& x) { + std::vector copied; + copied.reserve(x.symbols.size()); + for (SymbolicDim op : x.symbols) copied.push_back(GetRootSymbolicDim(op)); + + std::sort( + copied.begin(), copied.end(), [&](SymbolicDim lhs, SymbolicDim rhs) { + return CompareSymbolicDimNames(lhs.GetSymName(), rhs.GetSymName()); + }); + SymbolicDimProduct new_x; + new_x.factor = x.factor; + for (SymbolicDim op : copied) { + if (!op.IsDynamic()) { + new_x.factor *= op.GetDimSize(); + } else { + new_x.symbols.push_back(op); + } + } + return new_x; +} + +std::pair +SymbolicDimMgr::SimplifySymbolicDimProductPair(const SymbolicDimProduct& x, + const SymbolicDimProduct& y) { + // First do some basic clean up (e.g. folding const symbolic dim op into the + // fator field) + auto lhs = SimplifySymbolicDimProduct(x); + auto rhs = SimplifySymbolicDimProduct(y); + + SymbolicDimProduct new_lhs, new_rhs; + int64_t gcd_factor = std::gcd(std::abs(lhs.factor), std::abs(rhs.factor)); + + // 0 * lhs_symbols = 0 * rhs_symbols, no more information. + // Just return empty new_lhs & new_rhs + if (!gcd_factor) + return std::make_pair(std::move(new_lhs), std::move(new_rhs)); + + // Canonicalization factor form: always let the smaller factor being positive + // number. 
+ if (std::abs(lhs.factor) < std::abs(rhs.factor)) { + if (lhs.factor < 0) gcd_factor = -gcd_factor; + } else { + if (rhs.factor < 0) gcd_factor = -gcd_factor; + } + + new_lhs.factor = lhs.factor / gcd_factor; + new_rhs.factor = rhs.factor / gcd_factor; + + std::unordered_map lhs_symbol_map; + std::unordered_map rhs_symbol_map; + + for (SymbolicDim op : lhs.symbols) ++lhs_symbol_map[op]; + for (SymbolicDim op : rhs.symbols) ++rhs_symbol_map[op]; + + for (SymbolicDim op : lhs.symbols) { + auto it = rhs_symbol_map.find(op); + if (it != rhs_symbol_map.end() && op.GetKnownNonSizeZero()) { + if (--it->second == 0) rhs_symbol_map.erase(it); + continue; + } + new_lhs.symbols.push_back(op); + } + + for (SymbolicDim op : rhs.symbols) { + auto it = lhs_symbol_map.find(op); + if (it != lhs_symbol_map.end() && op.GetKnownNonSizeZero()) { + if (--it->second == 0) lhs_symbol_map.erase(it); + continue; + } + new_rhs.symbols.push_back(op); + } + + if (!new_lhs.factor) new_lhs.symbols.clear(); + if (!new_rhs.factor) new_rhs.symbols.clear(); + + return std::make_pair(std::move(new_lhs), std::move(new_rhs)); +} + +const std::string SymbolicDimMgr::GetNextName() { + std::string name; + do { + name = "S" + std::to_string(next_symbolic_idx_++); + } while (!symbol_name_set_.insert(name).second); + return name; +} + +SymbolicDim SymbolicDimMgr::NewSymbolicDim(const std::string& name) { + auto func_op = symbol_table_.getOp()->dyn_cast(); + assert(func_op); + Builder builder = Builder(m_.ir_context(), func_op.block()); + // default settting dim != 0 + dialect::SymbolicDim symbol = + builder.Build(name.empty() ? GetNextName() : name, + ShapedTypeInterface::kDynamic, + false, + false, + false, + true); + symbol_dim_union_set_[symbol] = symbol; + symbol_table_.insert(symbol); + return symbol; +} + +SymbolicDim SymbolicDimMgr::NewConstantSymbolicDim(int64_t val) { + auto it = constant_symbolic_dim_map_.find(val); + if (it == constant_symbolic_dim_map_.end()) { + auto name = "C" + std::to_string(val); + it = constant_symbolic_dim_map_ + .insert(std::make_pair(val, NewSymbolicDim(name))) + .first; + it->second.SetDimSize(val); + if (val == -1) it->second.UpdateKnownNegativeOne(true); + if (val >= 0) it->second.UpdateKnownNonNegative(true); + if (val != 1) it->second.UpdateKnownNonSizeOne(true); + if (val != 0) it->second.UpdateKnownNonSizeZero(true); + } + return GetRootSymbolicDim(it->second); +} + +std::vector SymbolicDimMgr::CreateSymbolicDimsForRankedValue( + Value value) { + std::vector symbols; + auto dims = value.type().dyn_cast().dims(); + for (int idx = 0; idx < dims.size(); ++idx) { + symbols.push_back(dims[idx] == ShapedTypeInterface::kDynamic + ? 
NewSymbolicDim() + : NewConstantSymbolicDim(dims[idx])); + } + return symbols; +} + +SymbolicDim SymbolicDimMgr::GetRootSymbolicDim(SymbolicDim symbol) { + SymbolicDim current = symbol; + std::vector path; + while (symbol_dim_union_set_[current] != current) { + path.push_back(current); + current = symbol_dim_union_set_[current]; + } + for (SymbolicDim sym : path) symbol_dim_union_set_[sym] = current; + return current; +} + +bool SymbolicDimMgr::IsSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { + SymbolicDim lhs_root = GetRootSymbolicDim(lhs); + SymbolicDim rhs_root = GetRootSymbolicDim(rhs); + return lhs_root == rhs_root; +} + +bool SymbolicDimMgr::MapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { + SymbolicDim lhs_root = GetRootSymbolicDim(lhs); + SymbolicDim rhs_root = GetRootSymbolicDim(rhs); + + if (lhs_root != rhs_root) { + if (CompareSymbolicDimNames(lhs_root.GetSymName(), rhs_root.GetSymName())) { + if (!lhs_root.Merge(rhs_root)) return false; + symbol_dim_union_set_[rhs_root] = lhs_root; + } else { + if (!rhs_root.Merge(lhs_root)) return false; + symbol_dim_union_set_[lhs_root] = rhs_root; + } + product_equality_map_updated_ = false; + } + return true; +} + +SymbolicDimProduct* SymbolicDimMgr::SymbolicDimProductDivide( + const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { + SymbolicDimProduct new_lhs, new_rhs; + std::tie(new_lhs, new_rhs) = SimplifySymbolicDimProductPair(lhs, rhs); + + if (new_lhs.factor == 0 || new_rhs.factor == 0) return nullptr; + if (new_lhs.factor % new_rhs.factor != 0) return nullptr; + if (new_lhs.symbols.size() < new_rhs.symbols.size()) return nullptr; + + SymbolicDimProduct* result = new SymbolicDimProduct(); + result->factor = new_lhs.factor / new_rhs.factor; + + std::unordered_map sym_proc_map; + for (SymbolicDim sym : new_rhs.symbols) ++sym_proc_map[sym]; + + for (SymbolicDim sym : new_lhs.symbols) { + auto it = sym_proc_map.find(sym); + if (it == sym_proc_map.end()) { + result->symbols.push_back(sym); + continue; + } + if (--it->second == 0) { + sym_proc_map.erase(it); + continue; + } + } + + if (!sym_proc_map.empty()) return nullptr; + return result; +} + +bool SymbolicDimMgr::IsMultipleOfKnownSymbolicDimProductEqualPair( + const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { + for (auto& pair_outter : product_equality_map_) { + const SymbolicDimProduct& x = pair_outter.first; + auto factor_x = SymbolicDimProductDivide(lhs, x); + if (!factor_x) continue; + for (auto& pair_inner : pair_outter.second) { + if (!pair_inner.second) continue; + const SymbolicDimProduct& y = pair_inner.first; + auto factor_y = SymbolicDimProductDivide(rhs, y); + if (!factor_y || (*factor_x) != (*factor_y)) continue; + return true; + } + } + + return false; +} + +bool SymbolicDimMgr::UpdateProductEqualityMap() { + // Return true if nothing is updated. 
+ if (product_equality_map_updated_) return true; + + SymbolicDimProductMap new_map; + std::unordered_set product_set; + for (auto& pair_outter : product_equality_map_) { + const SymbolicDimProduct& x = pair_outter.first; + for (auto& pair_inner : pair_outter.second) { + if (!pair_inner.second) continue; + + const SymbolicDimProduct& y = pair_inner.first; + SymbolicDimProduct new_x, new_y; + std::tie(new_x, new_y) = SimplifySymbolicDimProductPair(x, y); + if (new_x == new_y) continue; + + new_map[new_x][new_y] = new_map[new_y][new_x] = true; + product_set.insert(new_x); + product_set.insert(new_y); + } + } + // hash function of SymbolicDimProduct is expensive, thus we map it to integer + // domain first. + std::unordered_map symProd2Idx; + std::vector idx2SymProd(product_set.size()); + std::vector idx2root(product_set.size()); + for (auto& x : product_set) { + size_t idx = symProd2Idx.size(); + symProd2Idx[&x] = idx; + idx2SymProd[idx] = &x; + idx2root[idx] = idx; + } + + auto getRootIdx = [&](size_t root) { + std::vector path; + while (idx2root[root] != root) { + path.push_back(root); + root = idx2root[root]; + } + for (size_t idx : path) idx2root[idx] = root; + return root; + }; + + for (size_t x = 0; x < symProd2Idx.size(); ++x) { + auto& xProd = *idx2SymProd[x]; + auto& rowMap = new_map[xProd]; + size_t xRoot = getRootIdx(x); + for (size_t y = x; y < symProd2Idx.size(); ++y) { + auto& yProd = *idx2SymProd[y]; + if (!rowMap[yProd]) continue; + idx2root[getRootIdx(y)] = xRoot; + } + } + + for (size_t x = 0; x < symProd2Idx.size(); ++x) + for (size_t y = x; y < symProd2Idx.size(); ++y) { + if (getRootIdx(x) != getRootIdx(y)) continue; + auto& xSymProd = *idx2SymProd[x]; + auto& ySymProd = *idx2SymProd[y]; + + new_map[xSymProd][ySymProd] = new_map[ySymProd][xSymProd] = true; + } + + product_equality_map_ = std::move(new_map); + + for (auto& x : product_set) + for (auto& y : product_set) { + if (!product_equality_map_[x][y]) continue; + product_equality_map_[x][y] = product_equality_map_[y][x] = false; + if (!IsMultipleOfKnownSymbolicDimProductEqualPair(x, y)) { + product_equality_map_[x][y] = product_equality_map_[y][x] = true; + } + } + + std::unordered_set toRemove; + for (auto& x : product_set) { + if (std::all_of(product_set.begin(), + product_set.end(), + [&](const SymbolicDimProduct& y) { + return !product_equality_map_[x][y]; + })) { + toRemove.insert(x); + } + } + + for (auto& x : toRemove) { + product_equality_map_.erase(x); + } + + product_equality_map_updated_ = true; + return true; +} + +bool SymbolicDimMgr::IsSymbolicDimProductEqual(const SymbolicDimProduct& lhs, + const SymbolicDimProduct& rhs) { + SymbolicDimProduct new_lhs, new_rhs; + std::tie(new_lhs, new_rhs) = SimplifySymbolicDimProductPair(lhs, rhs); + + // Return true for identity case. 
+ if (new_lhs == new_rhs) return true; + IR_ENFORCE(UpdateProductEqualityMap(), "Update product equality map failed."); + return IsMultipleOfKnownSymbolicDimProductEqualPair(new_lhs, new_rhs); +} + +bool SymbolicDimMgr::Save() { + using Name2SymbolFn = std::function; + auto update_attrs = [&](ArrayAttribute attrs, Name2SymbolFn fn) { + std::vector new_attrs; + for (Attribute attr : attrs.AsVector()) { + auto sym = fn(attr.dyn_cast().AsString()); + assert(sym); + SymbolicDim root = GetRootSymbolicDim(sym); + Attribute root_symbol = + StrAttribute::get(m_->ir_context(), root.GetSymName()); + new_attrs.push_back(root_symbol); + } + return ArrayAttribute::get(m_->ir_context(), new_attrs); + }; + + // TODO(liujinnan): update attributes attached in DenseTensorType + for (auto op : *(m_.block())) { + if (!op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) continue; + auto attrs = + op->attribute(SymbolicDim::GetSymbolicDimAttrName()); + auto symbolic_shape_attr = + update_attrs(attrs, [&](const std::string& name) { + return symbol_table_.Lookup(name); + }); + op->set_attribute(SymbolicDim::GetSymbolicDimAttrName(), + symbolic_shape_attr); + } + if (!UpdateProductEqualityMap()) { + return false; + } + std::unordered_set used_symbolic_ops; + std::vector used_symbol_names; + // TODO(liujinnan): collect uses in value. + auto collect_used_symbols = [&](ArrayAttribute attrs) { + for (Attribute attr : attrs.AsVector()) { + auto sym = symbol_table_.Lookup( + attr.dyn_cast().AsString()); + assert(sym); + if (used_symbolic_ops.insert(sym).second) + used_symbol_names.push_back(sym.GetSymName()); + } + }; + for (auto op : *(m_.block())) { + if (!op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) continue; + auto attrs = + op->attribute(SymbolicDim::GetSymbolicDimAttrName()); + collect_used_symbols(attrs); + } + auto func_op = symbol_table_.getOp()->dyn_cast(); + assert(func_op); + for (auto& p : symbol_dim_union_set_) { + if (!used_symbolic_ops.count(p.first)) { + func_op.block()->erase(*(p.first.operation())); + } + } + + std::vector candidates; + for (auto& outter : product_equality_map_) { + if (std::any_of( + outter.first.symbols.begin(), + outter.first.symbols.end(), + [&](SymbolicDim sym) { return used_symbolic_ops.count(sym) == 0; })) + candidates.push_back(outter.first); + } + + for (auto& prod : candidates) product_equality_map_.erase(prod); + for (auto& outter : product_equality_map_) { + std::vector candidates; + for (auto& inner : outter.second) { + if (std::any_of(inner.first.symbols.begin(), + inner.first.symbols.end(), + [&](SymbolicDim sym) { + return used_symbolic_ops.count(sym) == 0; + })) + candidates.push_back(outter.first); + } + for (auto& prod : candidates) outter.second.erase(prod); + } + + std::sort(used_symbol_names.begin(), + used_symbol_names.end(), + [&](const std::string& lhs, const std::string& rhs) { + return CompareSymbolicDimNames(lhs, rhs); + }); + int non_const_dims_num = 0; + std::unordered_map name_mapping; + for (const auto& name : used_symbol_names) { + if (name.size() > 0 && name[0] == 'C') { + name_mapping[name] = name; + } else { + name_mapping[name] = ("S" + std::to_string(non_const_dims_num++)); + } + } + + std::unordered_map name_to_symbol; + for (SymbolicDim op : used_symbolic_ops) { + auto name = op.GetSymName(); + op.SetSymName(name_mapping[name]); + name_to_symbol[name] = op; + } + + for (auto op : *(m_.block())) { + if (!op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) continue; + auto attrs = + 
op->attribute(SymbolicDim::GetSymbolicDimAttrName()); + auto symbolic_shape_attr = update_attrs( + attrs, [&](const std::string& name) { return name_to_symbol[name]; }); + op->set_attribute(SymbolicDim::GetSymbolicDimAttrName(), + symbolic_shape_attr); + } + + // TODO(liujinnan): update attributes attached to values. + + return SaveShapeConstraintGraph(); +} + +bool SymbolicDimMgr::SaveShapeConstraintGraph() { + auto func_op = symbol_table_.getOp()->dyn_cast(); + assert(func_op); + auto op_it = func_op.block()->rbegin(); + while (op_it != func_op.block()->rend()) { + if (((*op_it)->isa()) || + ((*op_it)->isa())) + op_it++; + else + op_it = decltype(op_it)(func_op.block()->erase(*(*op_it))); + } + + // save product equal predicate + Builder builder = Builder(m_->ir_context(), func_op.block()); + auto build_operands = [&](const SymbolicDimProduct& prod) { + std::vector values; + + if (prod.factor != 1) { + values.push_back( + builder + .Build( + Int32Attribute::get(m_->ir_context(), prod.factor), + Int32Type::get(m_->ir_context())) + ->result(0)); + } + for (SymbolicDim sym : prod.symbols) { + values.push_back(builder.Build(sym.GetSymName()).out()); + } + return values; + }; + std::vector sorted_product_vec; + for (auto& p : product_equality_map_) sorted_product_vec.push_back(p.first); + std::sort(sorted_product_vec.begin(), + sorted_product_vec.end(), + CompareSymbolicDimProduct); + for (auto& x : sorted_product_vec) { + for (auto& y : sorted_product_vec) { + if (!CompareSymbolicDimProduct(x, y)) continue; + if (!product_equality_map_[x][y]) continue; + auto lhs_operands = build_operands(x); + auto rhs_operands = build_operands(y); + builder.Build(lhs_operands, rhs_operands); + } + } + return true; +} +} // namespace pir diff --git a/paddle/pir/dialect/shape/utils/shape_optimization_utils.h b/paddle/pir/dialect/shape/utils/shape_optimization_utils.h index 7f31a4fb55cf1..fdec957aa6be7 100644 --- a/paddle/pir/dialect/shape/utils/shape_optimization_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_optimization_utils.h @@ -13,3 +13,97 @@ // limitations under the License. 
#pragma once +#include +#include "paddle/pir/dialect/shape/utils/symbol_table.h" + +namespace pir { +using dialect::SymbolicDim; + +struct SymbolicDimProduct { + std::vector symbols; + int64_t factor = 1; + bool empty() { return factor == 1 && symbols.empty(); } + friend inline bool operator==(const SymbolicDimProduct& lhs, + const SymbolicDimProduct& rhs) { + return lhs.factor == rhs.factor && lhs.symbols == rhs.symbols; + } + + friend inline bool operator!=(const SymbolicDimProduct& lhs, + const SymbolicDimProduct& rhs) { + return !(lhs == rhs); + } +}; + +struct SymDimHasher { + size_t operator()(const dialect::SymbolicDim& symbol) const noexcept { + return std::hash{}(symbol.operation()); + } +}; + +struct SymProductHasher { + size_t operator()(const SymbolicDimProduct& symProd) const noexcept { + size_t hash = std::hash{}(symProd.symbols.size()); + for (auto& symbol : symProd.symbols) { + hash = hash_combine(hash, SymDimHasher{}(symbol)); // NOLINT + } + hash = hash_combine(hash, std::hash{}(symProd.factor)); + return hash; + } +}; + +class SymbolicDimMgr { + public: + explicit SymbolicDimMgr(ModuleOp m); + bool Load(); + SymbolicDim NewSymbolicDim(const std::string& name = {}); + SymbolicDim NewConstantSymbolicDim(int64_t val); + std::vector CreateSymbolicDimsForRankedValue(Value value); + SymbolicDim GetRootSymbolicDim(SymbolicDim symbol); + bool IsSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); + bool MapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); + SymbolicDimProduct SimplifySymbolicDimProduct(const SymbolicDimProduct& x); + std::pair + SimplifySymbolicDimProductPair(const SymbolicDimProduct& x, + const SymbolicDimProduct& y); + SymbolicDimProduct* SymbolicDimProductDivide(const SymbolicDimProduct& x, + const SymbolicDimProduct& y); + bool Save(); + bool IsSymbolicDimProductEqual(const SymbolicDimProduct& lhs, + const SymbolicDimProduct& rhs); + + bool MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, + const SymbolicDimProduct& rhs); + SymbolTable& symbolTable() { return symbol_table_; } + + private: + const std::string GetNextName(); + bool SaveShapeConstraintGraph(); + bool LoadShapeConstraintGraph(); + bool UpdateProductEqualityMap(); + bool IsMultipleOfKnownSymbolicDimProductEqualPair( + const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs); + + private: + ModuleOp m_; + + SymbolTable symbol_table_; + + int64_t next_symbolic_idx_ = 0; + + std::unordered_set symbol_name_set_; + + std::unordered_map + symbol_dim_union_set_; + + std::unordered_map constant_symbolic_dim_map_; + + // product_equality_map_[A][B] == true : Product[A] == Product[B] + using SymbolicDimProductMap = std::unordered_map< + SymbolicDimProduct, + std::unordered_map, + SymProductHasher>; + SymbolicDimProductMap product_equality_map_; + bool product_equality_map_updated_ = true; +}; + +} // namespace pir diff --git a/paddle/pir/dialect/shape/utils/shape_utils.cc b/paddle/pir/dialect/shape/utils/shape_utils.cc index ad2cc1d956918..4e4c87ed30f86 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_utils.cc @@ -129,740 +129,6 @@ bool ShapeConstraintIRAnalysis::IsProductEqual(Value lhs, return mgr_.IsSymbolicDimProductEqual(lhs_prod, rhs_prod); } -// Gives a consistent order of a list op SymbolicDim Ops -bool CompareSymbolicDimNames(const std::string& lhs, const std::string& rhs) { - // S -> unknown dimension size at compile time - // C -> constant dimension size at compile time - if (lhs.size() < 1 || (lhs[0] != 'S' && lhs[0] != 'C')) 
return lhs < rhs; - if (rhs.size() < 1 || (rhs[0] != 'S' && rhs[0] != 'C')) return lhs < rhs; - int64_t lhs_idx = 0, rhs_idx = 0; - try { - lhs_idx = stol(lhs.substr(1)); - rhs_idx = stol(rhs.substr(1)); - } catch (const std::exception& e) { - IR_THROW("Invalid symbolic name"); - } - return (lhs[0] < rhs[0]) || (lhs[0] == rhs[0] && lhs_idx < rhs_idx); -} - -// Gives a consistent order of a list op SymbolicDimProducts -bool CompareSymbolicDimProduct(SymbolicDimProduct& lhs, // NOLINT - SymbolicDimProduct& rhs) { // NOLINT - if (lhs.symbols.size() < rhs.symbols.size()) return true; - if (lhs.symbols.size() == rhs.symbols.size()) { - for (size_t idx = 0; idx < lhs.symbols.size(); ++idx) { - const std::string lhs_name = lhs.symbols[idx].GetSymName(); - const std::string rhs_name = rhs.symbols[idx].GetSymName(); - if (CompareSymbolicDimNames(lhs_name, rhs_name)) return true; - if (lhs_name != rhs_name) return false; - } - } - return false; -} - -bool SymbolicDimMgr::Load() { - auto func_op = symbol_table_.getOp()->dyn_cast(); - assert(func_op); - for (auto op_ : *(func_op.block())) { - symbol_table_.insert(op_); - if (SymbolicDim op = op_->dyn_cast()) { - symbolDimUnionSet_[op] = op; - symbolNameSet_.insert(op.GetSymName()); - } - } - return LoadShapeConstraintGraph(); -} - -bool SymbolicDimMgr::LoadShapeConstraintGraph() { - // TODO(liujinnan): add more constraint function. currently, only support - // tie_product_equal. - auto constraint_vec = - symbol_table_.Lookup("tie_product_equal"); - - if (!constraint_vec.size()) return true; - - auto build_sym_product = [&](std::vector range, - SymbolicDimProduct& product) { - for (Value v : range) { - auto definingOp = v.dyn_cast().owner(); - if (auto constOp = definingOp->dyn_cast()) { - product.factor *= constOp.value().dyn_cast().data(); - continue; - } else if (auto dimOp = definingOp->dyn_cast()) { - auto sym = symbol_table_.Lookup(dimOp.getName()); - if (!sym) return false; - product.symbols.push_back(sym); - continue; - } - return false; - } - return true; - }; - - for (auto op : constraint_vec) { - SymbolicDimProduct lhs, rhs; - if (!build_sym_product(op.lhs(), lhs) || - !build_sym_product(op.rhs(), rhs) || - !MapSymbolicDimProductEqual(lhs, rhs)) - return false; - } - return true; -} - -bool SymbolicDimMgr::MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, - const SymbolicDimProduct& rhs) { - SymbolicDimProduct new_lhs, new_rhs; - std::tie(new_lhs, new_rhs) = SimplifySymbolicDimProductPair(lhs, rhs); - - // early return for identity case. 
- if (new_lhs == new_rhs) return true; - - if (new_lhs.factor == new_rhs.factor && new_lhs.symbols.size() == 1 && - new_rhs.symbols.size() == 1) { - return MapSymbolicDimEqual(new_lhs.symbols[0], new_rhs.symbols[0]); - } else if (new_lhs.symbols.size() == 0 && new_rhs.symbols.size() == 1 && - new_rhs.factor == 1) { - return MapSymbolicDimEqual(NewConstantSymbolicDim(new_lhs.factor), - new_rhs.symbols[0]); - } else if (new_rhs.symbols.size() == 0 && new_lhs.symbols.size() == 1 && - new_lhs.factor == 1) { - return MapSymbolicDimEqual(NewConstantSymbolicDim(new_rhs.factor), - new_lhs.symbols[0]); - } - - productEqualityMap_[new_lhs][new_rhs] = - productEqualityMap_[new_rhs][new_lhs] = true; - - productEqualityMapUpdated_ = false; - return true; -} - -std::pair -SymbolicDimMgr::SimplifySymbolicDimProductPair(const SymbolicDimProduct& x, - const SymbolicDimProduct& y) { - auto lhs = SimplifySymbolicDimProduct(x); - auto rhs = SimplifySymbolicDimProduct(y); - - SymbolicDimProduct new_lhs, new_rhs; - int64_t gcd_factor = std::gcd(std::abs(lhs.factor), std::abs(rhs.factor)); - if (!gcd_factor) - return std::make_pair(std::move(new_lhs), std::move(new_rhs)); - if (std::abs(lhs.factor) < std::abs(rhs.factor)) { - if (lhs.factor < 0) gcd_factor = -gcd_factor; - } else { - if (rhs.factor < 0) gcd_factor = -gcd_factor; - } - - new_lhs.factor = lhs.factor / gcd_factor; - new_rhs.factor = rhs.factor / gcd_factor; - - std::unordered_map lhs_symbol_map; - std::unordered_map rhs_symbol_map; - for (SymbolicDim op : lhs.symbols) ++lhs_symbol_map[op]; - for (SymbolicDim op : rhs.symbols) ++rhs_symbol_map[op]; - - for (SymbolicDim op : lhs.symbols) { - auto it = rhs_symbol_map.find(op); - if (it != rhs_symbol_map.end() && op.GetKnownNonSizeZero()) { - if (--it->second == 0) rhs_symbol_map.erase(it); - continue; - } - new_lhs.symbols.push_back(op); - } - - for (SymbolicDim op : rhs.symbols) { - auto it = lhs_symbol_map.find(op); - if (it != lhs_symbol_map.end() && op.GetKnownNonSizeZero()) { - if (--it->second == 0) lhs_symbol_map.erase(it); - continue; - } - new_rhs.symbols.push_back(op); - } - - if (!new_lhs.factor) new_lhs.symbols.clear(); - if (!new_rhs.factor) new_rhs.symbols.clear(); - - return std::make_pair(std::move(new_lhs), std::move(new_rhs)); -} - -SymbolicDimProduct SymbolicDimMgr::SimplifySymbolicDimProduct( - const SymbolicDimProduct& x) { - std::vector copied; - copied.reserve(x.symbols.size()); - for (SymbolicDim op : x.symbols) copied.push_back(GetRootSymbolicDim(op)); - - sort(copied.begin(), copied.end(), [&](SymbolicDim lhs, SymbolicDim rhs) { - return CompareSymbolicDimNames(lhs.GetSymName(), rhs.GetSymName()); - }); - SymbolicDimProduct newX; - newX.factor = x.factor; - for (SymbolicDim op : copied) { - if (!op.IsDynamic()) { - newX.factor *= op.GetDimSize(); - } else { - newX.symbols.push_back(op); - } - } - return newX; -} - -const std::string SymbolicDimMgr::GetNextName() { - std::string name; - do { - name = "S" + std::to_string(nextSymbolicIdx_++); - } while (!symbolNameSet_.insert(name).second); - return name; -} - -SymbolicDimMgr::SymbolicDimMgr(ModuleOp m) : m_(m) { - for (auto op : *(m.block())) { - if (op->isa()) { - symbol_table_ = SymbolTable(op); - return; - } - } - Builder builder = Builder(m_.ir_context(), m_.block(), m_.block()->begin()); - dialect::FuncOp func = builder.Build(); - symbol_table_ = SymbolTable(func); -} - -SymbolicDim SymbolicDimMgr::NewSymbolicDim(const std::string& name) { - auto func_op = symbol_table_.getOp()->dyn_cast(); - assert(func_op); - Builder 
builder = Builder(m_.ir_context(), func_op.block()); - // default settting dim != 0 - dialect::SymbolicDim symbol = - builder.Build(name.empty() ? GetNextName() : name, - ShapedTypeInterface::kDynamic, - false, - false, - false, - true); - symbolDimUnionSet_[symbol] = symbol; - symbol_table_.insert(symbol); - return symbol; -} - -SymbolicDim SymbolicDimMgr::NewConstantSymbolicDim(int64_t val) { - auto it = constantSymbolicDimMap_.find(val); - if (it == constantSymbolicDimMap_.end()) { - auto name = "C" + std::to_string(val); - it = constantSymbolicDimMap_ - .insert(std::make_pair(val, NewSymbolicDim(name))) - .first; - it->second.SetDimSize(val); - if (val == -1) it->second.UpdateKnownNegativeOne(true); - if (val >= 0) it->second.UpdateKnownNonNegative(true); - if (val != 1) it->second.UpdateKnownNonSizeOne(true); - if (val != 0) it->second.UpdateKnownNonSizeZero(true); - } - return GetRootSymbolicDim(it->second); -} - -std::vector SymbolicDimMgr::CreateSymbolicDimsForRankedValue( - Value value) { - std::vector symbols; - auto dims = value.type().dyn_cast().dims(); - for (int idx = 0; idx < dims.size(); ++idx) { - symbols.push_back(dims[idx] == ShapedTypeInterface::kDynamic - ? NewSymbolicDim() - : NewConstantSymbolicDim(dims[idx])); - } - return symbols; -} - -SymbolicDim SymbolicDimMgr::GetRootSymbolicDim(SymbolicDim symbol) { - SymbolicDim current = symbol; - std::vector path; - while (symbolDimUnionSet_[current] != current) { - path.push_back(current); - current = symbolDimUnionSet_[current]; - } - for (SymbolicDim sym : path) symbolDimUnionSet_[sym] = current; - return current; -} - -bool SymbolicDimMgr::IsSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { - SymbolicDim lhsRoot = GetRootSymbolicDim(lhs); - SymbolicDim rhsRoot = GetRootSymbolicDim(rhs); - return lhsRoot == rhsRoot; -} - -bool SymbolicDimMgr::MapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { - SymbolicDim lhsRoot = GetRootSymbolicDim(lhs); - SymbolicDim rhsRoot = GetRootSymbolicDim(rhs); - - if (lhsRoot != rhsRoot) { - if (CompareSymbolicDimNames(lhsRoot.GetSymName(), rhsRoot.GetSymName())) { - if (!lhsRoot.Merge(rhsRoot)) return false; - symbolDimUnionSet_[rhsRoot] = lhsRoot; - } else { - if (!rhsRoot.Merge(lhsRoot)) return false; - symbolDimUnionSet_[lhsRoot] = rhsRoot; - } - } - return true; -} - -SymbolicDimProduct* SymbolicDimMgr::SymbolicDimProductDivide( - const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { - SymbolicDimProduct new_lhs, new_rhs; - std::tie(new_lhs, new_rhs) = SimplifySymbolicDimProductPair(lhs, rhs); - - if (new_lhs.factor == 0 || new_rhs.factor == 0) return nullptr; - if (new_lhs.factor % new_rhs.factor != 0) return nullptr; - if (new_lhs.symbols.size() < new_rhs.symbols.size()) return nullptr; - - SymbolicDimProduct* result = new SymbolicDimProduct(); - result->factor = new_lhs.factor / new_rhs.factor; - - std::unordered_map sym_proc_map; - for (SymbolicDim sym : new_rhs.symbols) ++sym_proc_map[sym]; - - for (SymbolicDim sym : new_lhs.symbols) { - auto it = sym_proc_map.find(sym); - if (it == sym_proc_map.end()) { - result->symbols.push_back(sym); - continue; - } - if (--it->second == 0) { - sym_proc_map.erase(it); - continue; - } - } - - if (!sym_proc_map.empty()) return nullptr; - return result; -} - -bool SymbolicDimMgr::IsMultipleOfKnownSymbolicDimProductEqualPair( - const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { - for (auto& pairOutter : productEqualityMap_) { - const SymbolicDimProduct& x = pairOutter.first; - auto factorX = 
SymbolicDimProductDivide(lhs, x); - if (!factorX) continue; - for (auto& pairInner : pairOutter.second) { - if (!pairInner.second) continue; - const SymbolicDimProduct& y = pairInner.first; - auto factorY = SymbolicDimProductDivide(rhs, y); - if (!factorY || (*factorX) != (*factorY)) continue; - return true; - } - } - - return false; -} - -bool SymbolicDimMgr::UpdateProductEqualityMap() { - // early return if nothing is updated. - if (productEqualityMapUpdated_) return true; - - SymbolicDimProductMap newMap; - std::unordered_set productSet; - for (auto& pairOutter : productEqualityMap_) { - const SymbolicDimProduct& x = pairOutter.first; - for (auto& pairInner : pairOutter.second) { - if (!pairInner.second) continue; - const SymbolicDimProduct& y = pairInner.first; - SymbolicDimProduct newX, newY; - std::tie(newX, newY) = SimplifySymbolicDimProductPair(x, y); - if (newX == newY) continue; - newMap[newX][newY] = newMap[newY][newX] = true; - productSet.insert(newX); - productSet.insert(newY); - } - } - // hash function of SymbolicDimProduct is expensive, thus we map it to integer - // domain first. - std::unordered_map symProd2Idx; - std::vector idx2SymProd(productSet.size()); - std::vector idx2root(productSet.size()); - for (auto& x : productSet) { - size_t idx = symProd2Idx.size(); - symProd2Idx[&x] = idx; - idx2SymProd[idx] = &x; - idx2root[idx] = idx; - } - - auto getRootIdx = [&](size_t root) { - std::vector path; - while (idx2root[root] != root) { - path.push_back(root); - root = idx2root[root]; - } - for (size_t idx : path) idx2root[idx] = root; - return root; - }; - - for (size_t x = 0; x < symProd2Idx.size(); ++x) { - auto& xProd = *idx2SymProd[x]; - auto& rowMap = newMap[xProd]; - size_t xRoot = getRootIdx(x); - for (size_t y = x; y < symProd2Idx.size(); ++y) { - auto& yProd = *idx2SymProd[y]; - if (!rowMap[yProd]) continue; - idx2root[getRootIdx(y)] = xRoot; - } - } - - for (size_t x = 0; x < symProd2Idx.size(); ++x) - for (size_t y = x; y < symProd2Idx.size(); ++y) { - if (getRootIdx(x) != getRootIdx(y)) continue; - auto& xSymProd = *idx2SymProd[x]; - auto& ySymProd = *idx2SymProd[y]; - - newMap[xSymProd][ySymProd] = newMap[ySymProd][xSymProd] = true; - } - - productEqualityMap_ = std::move(newMap); - - for (auto& x : productSet) - for (auto& y : productSet) { - if (!productEqualityMap_[x][y]) continue; - productEqualityMap_[x][y] = productEqualityMap_[y][x] = false; - if (!IsMultipleOfKnownSymbolicDimProductEqualPair(x, y)) { - productEqualityMap_[x][y] = productEqualityMap_[y][x] = true; - } - } - - std::unordered_set toRemove; - for (auto& x : productSet) { - if (std::all_of(productSet.begin(), - productSet.end(), - [&](const SymbolicDimProduct& y) { - return !productEqualityMap_[x][y]; - })) { - toRemove.insert(x); - } - } - - for (auto& x : toRemove) { - productEqualityMap_.erase(x); - } - - productEqualityMapUpdated_ = true; - return true; -} - -bool SymbolicDimMgr::IsSymbolicDimProductEqual(const SymbolicDimProduct& lhs, - const SymbolicDimProduct& rhs) { - SymbolicDimProduct new_lhs, new_rhs; - std::tie(new_lhs, new_rhs) = SimplifySymbolicDimProductPair(lhs, rhs); - - // early return for identity case. 
- if (new_lhs == new_rhs) return true; - IR_ENFORCE(UpdateProductEqualityMap(), "Update product equality map failed."); - return IsMultipleOfKnownSymbolicDimProductEqualPair(new_lhs, new_rhs); -} - -bool SymbolicDimMgr::Save() { - using Name2SymbolFn = std::function; - auto updateAttrs = [&](ArrayAttribute attrs, Name2SymbolFn fn) { - std::vector newAttrs; - for (Attribute attr : attrs.AsVector()) { - auto sym = fn(attr.dyn_cast().AsString()); - assert(sym); - SymbolicDim root = GetRootSymbolicDim(sym); - Attribute rootSymbol = - StrAttribute::get(m_->ir_context(), root.GetSymName()); - newAttrs.push_back(rootSymbol); - } - return ArrayAttribute::get(m_->ir_context(), newAttrs); - }; - - // TODO(liujinnan): update attributes attached in DenseTensorType - for (auto op : *(m_.block())) { - if (!op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) continue; - auto attrs = - op->attribute(SymbolicDim::GetSymbolicDimAttrName()); - auto symbolicShapeAttr = updateAttrs(attrs, [&](const std::string& name) { - return symbol_table_.Lookup(name); - }); - op->set_attribute(SymbolicDim::GetSymbolicDimAttrName(), symbolicShapeAttr); - } - if (!UpdateProductEqualityMap()) { - return false; - } - std::unordered_set usedSymbolicOps; - std::vector usedSymbolNames; - // TODO(liujinnan): collect uses in value. - auto collectUsedSymbols = [&](ArrayAttribute attrs) { - for (Attribute attr : attrs.AsVector()) { - auto sym = symbol_table_.Lookup( - attr.dyn_cast().AsString()); - assert(sym); - if (usedSymbolicOps.insert(sym).second) - usedSymbolNames.push_back(sym.GetSymName()); - } - }; - for (auto op : *(m_.block())) { - if (!op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) continue; - auto attrs = - op->attribute(SymbolicDim::GetSymbolicDimAttrName()); - collectUsedSymbols(attrs); - } - auto func_op = symbol_table_.getOp()->dyn_cast(); - assert(func_op); - for (auto& p : symbolDimUnionSet_) { - if (!usedSymbolicOps.count(p.first)) { - func_op.block()->erase(*(p.first.operation())); - } - } - - std::vector candidates; - for (auto& outter : productEqualityMap_) { - if (std::any_of( - outter.first.symbols.begin(), - outter.first.symbols.end(), - [&](SymbolicDim sym) { return usedSymbolicOps.count(sym) == 0; })) - candidates.push_back(outter.first); - } - - for (auto& prod : candidates) productEqualityMap_.erase(prod); - for (auto& outter : productEqualityMap_) { - std::vector candidates; - for (auto& inner : outter.second) { - if (std::any_of( - inner.first.symbols.begin(), - inner.first.symbols.end(), - [&](SymbolicDim sym) { return usedSymbolicOps.count(sym) == 0; })) - candidates.push_back(outter.first); - } - for (auto& prod : candidates) outter.second.erase(prod); - } - - std::sort(usedSymbolNames.begin(), - usedSymbolNames.end(), - [&](const std::string& lhs, const std::string& rhs) { - return CompareSymbolicDimNames(lhs, rhs); - }); - int numNonConstDims = 0; - std::unordered_map nameMapping; - for (const auto& name : usedSymbolNames) { - if (name.size() > 0 && name[0] == 'C') { - nameMapping[name] = name; - } else { - nameMapping[name] = ("S" + std::to_string(numNonConstDims++)); - } - } - - std::unordered_map name2Symbol; - for (SymbolicDim op : usedSymbolicOps) { - auto name = op.GetSymName(); - op.SetSymName(nameMapping[name]); - name2Symbol[name] = op; - } - - for (auto op : *(m_.block())) { - if (!op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) continue; - auto attrs = - op->attribute(SymbolicDim::GetSymbolicDimAttrName()); - auto symbolicShapeAttr = updateAttrs( - attrs, 
[&](const std::string& name) { return name2Symbol[name]; }); - op->set_attribute(SymbolicDim::GetSymbolicDimAttrName(), symbolicShapeAttr); - } - - // TODO(liujinnan): update attributes attached to values. - - return SaveShapeConstraintGraph(); -} - -bool SymbolicDimMgr::SaveShapeConstraintGraph() { - auto func_op = symbol_table_.getOp()->dyn_cast(); - assert(func_op); - auto op_it = func_op.block()->rbegin(); - while (op_it != func_op.block()->rend()) { - if (((*op_it)->isa()) || - ((*op_it)->isa())) - op_it++; - else - op_it = decltype(op_it)(func_op.block()->erase(*(*op_it))); - } - - Builder builder = Builder(m_->ir_context(), func_op.block()); - auto build_operands = [&](const SymbolicDimProduct& prod) { - std::vector values; - - if (prod.factor != 1) { - values.push_back( - builder - .Build( - Int32Attribute::get(m_->ir_context(), prod.factor), - Int32Type::get(m_->ir_context())) - ->result(0)); - } - for (SymbolicDim sym : prod.symbols) { - values.push_back(builder.Build(sym.GetSymName()).out()); - } - return values; - }; - std::vector sortedProductVec; - for (auto& p : productEqualityMap_) sortedProductVec.push_back(p.first); - std::sort(sortedProductVec.begin(), - sortedProductVec.end(), - CompareSymbolicDimProduct); - for (auto& x : sortedProductVec) { - for (auto& y : sortedProductVec) { - if (!CompareSymbolicDimProduct(x, y)) continue; - if (!productEqualityMap_[x][y]) continue; - auto lhsOperands = build_operands(x); - auto rhsOperands = build_operands(y); - builder.Build(lhsOperands, rhsOperands); - } - } - return true; -} - -ShapeComputationIRAnalysis::ShapeComputationIRAnalysis(ModuleOp m, - SymbolicDimMgr& mgr) - : m_(m), mgr_(mgr) {} - -bool ShapeComputationIRAnalysis::Run() { - // Make sure only run once. - if (initialized_) return false; - initialized_ = true; - auto buildShapeFunc = - std::bind(&ShapeComputationIRAnalysis::BuildShapeOnOperation, - this, - std::placeholders::_1); - if (!RunOnRegion(&(m_->region(0)), buildShapeFunc)) return false; - auto applyOpConstraintFunc = - std::bind(&ShapeComputationIRAnalysis::ApplyOpConstraint, - this, - std::placeholders::_1); - if (!RunOnRegion(&(m_->region(0)), applyOpConstraintFunc)) return false; - return true; -} - -bool ShapeComputationIRAnalysis::RunOnRegion(Region* region, func fn) { - for (Block* block : *region) { - if (!RunOnBlock(block, fn)) return false; - } - return true; -} - -bool ShapeComputationIRAnalysis::RunOnBlock(Block* block, func fn) { - // TODO(liujinnan): mapping block arguments - - std::vector op_list; - for (Operation* op : *block) op_list.push_back(op); - for (Operation* op : op_list) { - if (!RunOnOperation(op, fn)) return false; - } - return true; -} - -bool ShapeComputationIRAnalysis::RunOnOperation(Operation* op, func fn) { - for (size_t i = 0; i < op->num_regions(); ++i) { - if (!RunOnRegion(&(op->region(i)), fn)) return false; - } - return fn(op); -} - -bool ShapeComputationIRAnalysis::BuildShapeOnOperation(Operation* op) { - if (op->isa()) return true; - if (op->isa()) { - Value value = op->operand_source(0); - std::vector symbols; - if (op->HasAttribute(SymbolicDim::GetSymbolicDimAttrName())) { - auto attrs = - op->attribute(SymbolicDim::GetSymbolicDimAttrName()) - .AsVector(); - for (Attribute attr : attrs) { - auto sym = mgr_.symbolTable().Lookup( - attr.dyn_cast().AsString()); - assert(sym); - SymbolicDim root = mgr_.GetRootSymbolicDim(sym); - symbols.push_back(root); - } - } else { - symbols = mgr_.CreateSymbolicDimsForRankedValue(value); - std::vector attrs; - for (SymbolicDim sym : 
symbols) { - Attribute rootSymbol = - StrAttribute::get(m_->ir_context(), sym.GetSymName()); - attrs.push_back(rootSymbol); - } - op->set_attribute(SymbolicDim::GetSymbolicDimAttrName(), - ArrayAttribute::get(m_->ir_context(), attrs)); - } - rankedTensor2SymDims_[value] = std::move(symbols); - return true; - } - for (size_t i = 0; i < op->num_results(); ++i) { - if (!BuildShapeOnValue(op->result(i))) return false; - } - return true; -} - -bool ShapeComputationIRAnalysis::BuildShapeOnValue(Value value) { - Type type = value.type(); - if (IsIntOrIndex(type)) { - SymbolicDim sym = mgr_.NewSymbolicDim(); - value2SymDim_[value] = sym; - } else if (IsCandidateShapeTensorType(type)) { - auto shapedTy = type.dyn_cast(); - std::vector symbols; - for (size_t i = 0, d = shapedTy.GetShape()[0]; i < d; ++i) - symbols.push_back(mgr_.NewSymbolicDim()); - shapeTensor2SymDims_[value] = std::move(symbols); - } - return true; -} - -bool ShapeComputationIRAnalysis::ApplyOpConstraint(Operation* op) { - IR_ENFORCE(ApplyIndexOpConstraint(op), - "Fail to apply constraint for index op"); - IR_ENFORCE(ApplyTieShapeOpConstraint(op), - "Fail to apply constraint for tie_shape op"); - - // TODO(zhangbo63): add more constraints - return true; -} - -bool ShapeComputationIRAnalysis::ApplyIndexOpConstraint(Operation* op) { - if (op->num_results() == 0) return true; - - Type type = op->result(0).type(); - if (!IsIntOrIndex(type)) return true; - - if (auto dimOp = op->dyn_cast()) { - int64_t dimIndex = dimOp.index() - .dyn_cast() - .owner() - ->attribute("value") - .data(); - value2SymDim_[dimOp.out()].UpdateKnownNonNegative(true); - if (!mgr_.MapSymbolicDimEqual( - value2SymDim_[dimOp.out()], - rankedTensor2SymDims_[dimOp.source()][dimIndex])) { - return false; - } - - } else if (auto constOp = op->dyn_cast()) { - int64_t val = constOp.value().dyn_cast().data(); - if (!mgr_.MapSymbolicDimEqual(value2SymDim_[op->result(0)], - mgr_.NewConstantSymbolicDim(val))) { - return false; - } - } - // TODO(zhangbo63): add support for reifyInferShape. (e.g. 
mul/add) - return true; -} - -bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) { - if (auto tieShape = op->dyn_cast()) { - auto& value = rankedTensor2SymDims_[op->operand_source(0)]; - for (size_t idx = 0; idx < tieShape.dims().size(); ++idx) { - if (!mgr_.MapSymbolicDimEqual(value2SymDim_[tieShape.dims()[idx]], - value[idx])) - return false; - mgr_.GetRootSymbolicDim(value[idx]).UpdateKnownNonNegative(true); - } - } - return true; -} - bool IsIntOrIndex(Type type) { return type.isa() || type.isa() || type.isa() || type.isa() || diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index 3388971d32aac..72510f8a23c83 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/pir/dialect/shape/utils/shape_optimization_utils.h" #include "paddle/pir/dialect/shape/utils/symbol_table.h" namespace pir { @@ -49,101 +50,10 @@ class ShapeAnalysis { using dialect::SymbolicDim; -struct SymbolicDimProduct { - std::vector symbols; - int64_t factor = 1; - bool empty() { return factor == 1 && symbols.empty(); } - friend inline bool operator==(const SymbolicDimProduct& lhs, - const SymbolicDimProduct& rhs) { - return lhs.factor == rhs.factor && lhs.symbols == rhs.symbols; - } - - friend inline bool operator!=(const SymbolicDimProduct& lhs, - const SymbolicDimProduct& rhs) { - return !(lhs == rhs); - } -}; - -struct SymDimHasher { - size_t operator()(const dialect::SymbolicDim& symbol) const noexcept { - return std::hash{}(symbol.operation()); - } -}; - -struct SymProductHasher { - size_t operator()(const SymbolicDimProduct& symProd) const noexcept { - size_t hash = std::hash{}(symProd.symbols.size()); - for (auto& symbol : symProd.symbols) { - hash = hash_combine(hash, SymDimHasher{}(symbol)); // NOLINT - } - hash = hash_combine(hash, std::hash{}(symProd.factor)); - return hash; - } -}; - -class SymbolicDimMgr { - public: - explicit SymbolicDimMgr(ModuleOp m); - bool Load(); - SymbolicDim NewSymbolicDim(const std::string& name = {}); - SymbolicDim NewConstantSymbolicDim(int64_t val); - std::vector CreateSymbolicDimsForRankedValue(Value value); - SymbolicDim GetRootSymbolicDim(SymbolicDim symbol); - bool IsSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); - SymbolTable& symbolTable() { return symbol_table_; } - bool MapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); - SymbolicDimProduct SimplifySymbolicDimProduct(const SymbolicDimProduct& x); - std::pair - SimplifySymbolicDimProductPair(const SymbolicDimProduct& x, - const SymbolicDimProduct& y); - SymbolicDimProduct* SymbolicDimProductDivide(const SymbolicDimProduct& x, - const SymbolicDimProduct& y); - - bool Save(); - - bool IsSymbolicDimProductEqual(const SymbolicDimProduct& lhs, - const SymbolicDimProduct& rhs); - bool MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, - const SymbolicDimProduct& rhs); - - private: - const std::string GetNextName(); - bool UpdateProductEqualityMap(); - bool IsMultipleOfKnownSymbolicDimProductEqualPair( - const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs); - bool SaveShapeConstraintGraph(); - bool LoadShapeConstraintGraph(); - - private: - ModuleOp m_; - - SymbolTable symbol_table_; - - int64_t nextSymbolicIdx_ = 0; - - std::unordered_set symbolNameSet_; - - std::unordered_map symbolDimUnionSet_; - - std::unordered_map constantSymbolicDimMap_; - - // productEqualityMap_[A][B] == true : Product[A] == Product[B] - using 
SymbolicDimProductMap = std::unordered_map< - SymbolicDimProduct, - std::unordered_map, - SymProductHasher>; - SymbolicDimProductMap productEqualityMap_; - bool productEqualityMapUpdated_ = true; -}; - // A subclass to impement `ShapeAnalysis` on buffer level. // The implementation is based on shape constraint ir. class ShapeConstraintIRAnalysis : public ShapeAnalysis { public: - // Build shape related analysis on the provided `op`. - // This generally can be divided into two steps: - // 1, load exsiting shape constraint ir (e.g. symbolic dim ops) - // 2, build mapping between memref values and symbolic dim ops. explicit ShapeConstraintIRAnalysis(ModuleOp m); // auto-save updated shape constriant ir when destroying. @@ -156,12 +66,6 @@ class ShapeConstraintIRAnalysis : public ShapeAnalysis { // Returns true if the two value have the same symbolic shape. bool IsShapeEqual(Value lhs, Value rhs) override; - // Suppose: - // lhs_dim_idxs = {ld0, ld1, ...} - // rhs_dim_idxs = {rd0, rd1, ...} - // Returns true if: - // lhs.shape[ld0] * lhs.shape[ld1] * ... == - // rhs.shape[rd0] * rhs.shape[rd1] * ... bool IsProductEqual(Value lhs, std::vector lhs_dim_idxs, Value rhs, @@ -177,37 +81,6 @@ class ShapeConstraintIRAnalysis : public ShapeAnalysis { std::unordered_map> value_to_sym_dims_; }; -class ShapeComputationIRAnalysis { - public: - using func = std::function; - explicit ShapeComputationIRAnalysis(ModuleOp m, - SymbolicDimMgr& mgr); // NOLINT - bool Run(); - - private: - bool RunOnRegion(Region* region, func fn); - bool RunOnBlock(Block* block, func fn); - bool RunOnOperation(Operation* op, func fn); - - bool BuildShapeOnOperation(Operation* op); - bool BuildShapeOnValue(Value value); - - bool ApplyOpConstraint(Operation* op); - bool ApplyIndexOpConstraint(Operation* op); - bool ApplyTieShapeOpConstraint(Operation* op); - - bool initialized_ = false; - ModuleOp m_; - SymbolicDimMgr& mgr_; - - std::unordered_map value2SymDim_; - - // shape tensor is the 1D ranked tensor with int/index dtype. - std::unordered_map> shapeTensor2SymDims_; - - std::unordered_map> rankedTensor2SymDims_; -}; - bool IsIntOrIndex(Type type); bool IsCandidateShapeTensorType(Type ty); } // namespace pir diff --git a/test/cpp/pir/shape_dialect/constraint_pass_test.cc b/test/cpp/pir/shape_dialect/constraint_pass_test.cc index 7c645044a09d0..f5282727f7250 100644 --- a/test/cpp/pir/shape_dialect/constraint_pass_test.cc +++ b/test/cpp/pir/shape_dialect/constraint_pass_test.cc @@ -39,6 +39,7 @@ #include "paddle/pir/core/value.h" #include "paddle/pir/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/dialect/shape/ir/shape_op.h" +#include "paddle/pir/dialect/shape/transforms/shape_optimization.h" #include "paddle/pir/dialect/shape/transforms/shape_optimization_pass.h" #include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/pass/pass.h"
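
Usage note (editorial, not part of the patch): the sketch below illustrates how the two pieces produced by this split are expected to be driven together, based only on the constructors and the Load()/Run()/Save() methods declared in shape_optimization_utils.h and shape_optimization.h above. The wrapper function name and the way `module` is obtained are assumptions for illustration; the actual driver lives in shape_optimization_pass.cc.

    #include "paddle/pir/dialect/shape/transforms/shape_optimization.h"
    #include "paddle/pir/dialect/shape/utils/shape_optimization_utils.h"

    // Hypothetical driver: analyze and persist symbolic shape constraints
    // for an existing pir::ModuleOp (how the module is created is outside
    // the scope of this patch).
    bool RunShapeComputationAnalysis(pir::ModuleOp module) {
      // Bind (or create) the shape-constraint FuncOp and its symbol table.
      pir::SymbolicDimMgr mgr(module);
      // Load existing SymbolicDim ops and tie_product_equal constraints.
      if (!mgr.Load()) return false;
      // First pass builds SymbolicDims for ranked values, second pass applies
      // dim/constant/tie_shape constraints (see ShapeComputationIRAnalysis::Run above).
      pir::ShapeComputationIRAnalysis analysis(module, mgr);
      if (!analysis.Run()) return false;
      // Normalize symbol names and write the constraint graph back into the IR.
      return mgr.Save();
    }

Keeping the manager (symbol table, union-find state, product-equality map) in shape_optimization_utils and the per-operation analysis in shape_optimization mirrors the file split made by this patch.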