[CPU] Fix shape mismatching in fusing per channel (openvinotoolkit#11162)

* Fix shape mismatching in fusing per channel

* Change channelAxis data type to int
xuchen-intel authored May 20, 2022
1 parent 35ba009 commit 8886d0f
Showing 17 changed files with 143 additions and 69 deletions.
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/graph_optimizer.cpp
@@ -1864,7 +1864,7 @@ void GraphOptimizer::FusePerformedAsScaleShiftAndFakeQuantize(Graph &graph) {

const auto &outputShape = child->getOutputShapeAtPort(0);
VectorDims outputDims = outputShape.getDims();
const size_t channelPos = parent->getParentEdgeAt(0)->getParent()->getFusingAxis();
const auto channelPos = parent->getParentEdgeAt(0)->getParent()->getFusingAxis();

if (outputShape.isDynamic()) {
if (outputDims[channelPos] == Shape::UNDEFINED_DIM) {
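The `size_t` → `auto` change for `channelPos` is the visible tip of the commit: `getFusingAxis()` now returns a signed `int`, because after a Reduce with `keep_dims = false` the channel axis may simply no longer exist. A minimal standalone sketch of that convention (hypothetical free function, mirroring the `getChannelAxis` helper introduced further down):

```cpp
// Sketch only: a signed fusing axis lets -1 mean "the channel axis was reduced away".
// With an unsigned size_t, a -1 would wrap around and silently index out of range.
#include <cstdio>
#include <vector>

int channelAxisAfterReduction(const std::vector<int>& reducedAxes, bool keepDims) {
    int channelAxis = 1;                     // default NCHW-style channel dimension
    if (!keepDims) {
        for (int axis : reducedAxes) {
            if (axis == 1) return -1;        // channel dimension no longer exists
            if (axis == 0) channelAxis = 0;  // leading dim removed, channel shifts to 0
        }
    }
    return channelAxis;
}

int main() {
    std::printf("%d\n", channelAxisAfterReduction({1}, false));  // -1
    std::printf("%d\n", channelAxisAfterReduction({0}, false));  //  0
    std::printf("%d\n", channelAxisAfterReduction({2}, true));   //  1
}
```

The hunks that follow (up to node.cpp) come from the SnippetsMarkSkipped transformation, where the real `getChannelAxis` helper is added.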
Expand Up @@ -68,7 +68,7 @@ bool SupportsFusingWithConvolution_SumActivation(const std::shared_ptr<const Nod
ov::is_type<ngraph::op::v5::Round>(node);
}

bool canBePerformedAsScaleShift(const std::shared_ptr<const Node> &node, const size_t channelAxis) {
bool canBePerformedAsScaleShift(const std::shared_ptr<const Node> &node, const int channelAxis) {
size_t fusingPort = 0;
size_t numNonConstInputs = 0;
ov::PartialShape dataShape;
@@ -110,7 +110,7 @@ bool canBePerformedAsScaleShift(const std::shared_ptr<const Node> &node, const s
isBroadcastableToDataInput();
}

bool SupportsFusingWithConvolution_Simple(const std::shared_ptr<const Node> &node, const size_t channelAxis = 1) {
bool SupportsFusingWithConvolution_Simple(const std::shared_ptr<const Node> &node, const int channelAxis = 1) {
return SupportsFusingWithConvolution_SumActivation(node) ||
ov::is_type<ngraph::op::Tanh>(node) ||
ov::is_type<ngraph::op::v0::Gelu>(node) ||
@@ -135,7 +135,22 @@ bool isSuitableBinaryConvolutionParent(const std::shared_ptr<const Node> &node)
const bool has_only_child = (out.size() == 1) && (out[0].get_target_inputs().size() == 1);
return is_suitable_node && has_only_child;
}
bool isSuitableMiscParent(const std::shared_ptr<const Node> &node) {
int getChannelAxis(const ov::AxisSet &axes, bool keep_dims) {
int channelAxis = 1;
if (!keep_dims) {
for (auto &axis : axes) {
if (axis == 1) {
// channel axis has been reduced and doesn't exist any more
channelAxis = -1;
break;
} else if (axis == 0) {
channelAxis = 0;
}
}
}
return channelAxis;
}
bool isSuitableMiscParent(const std::shared_ptr<const Node> &node, int &channelAxis) {
const bool is_suitable_node = ov::is_type<ngraph::op::v0::MVN>(node) ||
ov::is_type<ngraph::op::v6::MVN>(node) ||
ov::is_type<ngraph::op::v0::NormalizeL2>(node) ||
@@ -147,6 +162,11 @@ bool isSuitableMiscParent(const std::shared_ptr<const Node> &node) {
ov::is_type<ngraph::op::util::ArithmeticReductionKeepDims>(node) ||
ov::is_type<ngraph::op::util::LogicalReductionKeepDims>(node) ||
ov::is_type<ngraph::opset1::GroupConvolutionBackpropData>(node);
if (const auto reduce = std::dynamic_pointer_cast<const ngraph::op::util::ArithmeticReductionKeepDims>(node)) {
channelAxis = getChannelAxis(reduce->get_reduction_axes(), reduce->get_keep_dims());
} else if (const auto reduce = std::dynamic_pointer_cast<const ngraph::op::util::LogicalReductionKeepDims>(node)) {
channelAxis = getChannelAxis(reduce->get_reduction_axes(), reduce->get_keep_dims());
}
// has a single output, connected to a single child
const auto out = node->outputs();
const bool has_only_child = (out.size() == 1) && (out[0].get_target_inputs().size() == 1);
@@ -167,7 +187,7 @@ bool isSuitablePoolChild(const std::shared_ptr<const Node> &node) {
const bool has_only_child = (out.size() == 1) && (out[0].get_target_inputs().size() == 1);
return is_suitable_node && has_only_child;
}
bool isSuitableChildForFusingSimple(const std::shared_ptr<const Node> &node, const size_t channelAxis = 1) {
bool isSuitableChildForFusingSimple(const std::shared_ptr<const Node> &node, int channelAxis = 1) {
// Note: Fusing child is allowed to have several users, but that must be the end of the chain
return SupportsFusingWithConvolution_Simple(node, channelAxis) && getNumNonConstInputs(node) == 1;
}
@@ -205,7 +225,7 @@ bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, Nod
// FuseMatMulAndSimpleOperation or FuseFullyConnectedAndSimpleOperation
// Invoke SupportsFusingWithConvolution_Simple directly instead of isSuitableChildForFusingSimple to
// eliminate getNumNonConstInputs() check
size_t fusingAxis;
int fusingAxis;
if (can_be_converted_to_FC)
fusingAxis = matmul_shape.size() == 3 ? 2 : 1;
else
@@ -300,6 +320,7 @@ void MarkSubgraphOpAsSkipped(const std::shared_ptr<Node> &node) {

bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
RUN_ON_MODEL_SCOPE(SnippetsMarkSkipped);
int channelAxis = 1;
for (auto &node : m->get_ordered_ops()) {
if (ngraph::op::is_constant(node))
continue;
@@ -313,15 +334,15 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
} else if (isSuitableBinaryConvolutionParent(node)) {
SetNodeFusingType(node, NodeFusingType::FusedWithBinaryConvolution);
continue;
} else if (isSuitableMiscParent(node)) {
} else if (isSuitableMiscParent(node, channelAxis)) {
SetNodeFusingType(node, NodeFusingType::FusedWithMisc);
continue;
} else if (isSuitableMatMulParent(node)) {
SetNodeFusingType(node, NodeFusingType::FusedWithMatMul);
continue;
}
for (const auto fusingChainType : getContinuableChains(node)) {
if (isSuitableChildForFusingSimple(node)) {
if (isSuitableChildForFusingSimple(node, channelAxis)) {
PropagateIfHasOnlyChild(node, fusingChainType);
} else if (fusingChainType == NodeFusingType::FusedWithConvolution ||
fusingChainType == NodeFusingType::FusedWithBinaryConvolution) {
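Within `SnippetsMarkSkipped::run_on_model`, the new `channelAxis` local starts at 1 and is refreshed only by `isSuitableMiscParent`, so the `isSuitableChildForFusingSimple` check further down the chain validates per-channel constants against the axis reported by the most recent reduction-style parent. A stripped-down sketch of that flow (hypothetical stand-in types, not the plugin's real node interface):

```cpp
#include <vector>

// Hypothetical stand-ins for the real node checks in this pass.
struct OpInfo {
    bool suitableMiscParent  = false;  // e.g. MVN, NormalizeL2, Reduce*
    int  parentChannelAxis   = 1;      // what isSuitableMiscParent would report
    bool suitableSimpleChild = false;  // e.g. an eltwise that may fuse per channel
};

void markSkippedSketch(const std::vector<OpInfo>& orderedOps) {
    int channelAxis = 1;  // default channel axis until a reduction says otherwise
    for (const auto& op : orderedOps) {
        if (op.suitableMiscParent) {
            channelAxis = op.parentChannelAxis;  // may become 0 or -1 after a Reduce
            continue;                            // the parent is marked FusedWithMisc
        }
        if (op.suitableSimpleChild) {
            // the child check receives channelAxis, so its per-channel constants are
            // validated against the post-reduction shape rather than a stale axis 1
        }
    }
}
```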
6 changes: 3 additions & 3 deletions src/plugins/intel_cpu/src/node.cpp
@@ -1130,11 +1130,11 @@ InferenceEngine::Layout Node::getWeightsLayoutByDims(SizeVector dims, bool isGro
}
}

void Node::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<MemoryPtr>& postOpsMem) {
void Node::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<MemoryPtr>& postOpsMem, const int channelAxis) {
IE_THROW() << "Fusing of " << NameFromType(this->getType()) << " operation is not implemented";
}

void Node::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<const void*>& postOpsMem) {
void Node::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<const void*>& postOpsMem, const int channelAxis) {
IE_THROW() << "Fusing of " << NameFromType(this->getType()) << " operation is not implemented";
}

@@ -1267,7 +1267,7 @@ bool Node::canBePerformedAsScaleShift(const Node *parentNode) const {
IE_ASSERT(parentNode);

size_t fusingPort = 0;
const size_t channelAxis = parentNode->getFusingAxis();
const auto channelAxis = parentNode->getFusingAxis();

for (size_t i = 0; i < getParentEdges().size(); i++) {
Node *node = getParentEdgesAtPort(i)[0]->getParent().get();
7 changes: 4 additions & 3 deletions src/plugins/intel_cpu/src/node.h
@@ -177,7 +177,8 @@ class Node {

bool isConstant();

virtual size_t getFusingAxis() const {
// return type int supports returning -1 in overrides when the channel axis doesn't exist
virtual int getFusingAxis() const {
return 1;
}

@@ -562,8 +563,8 @@ class Node {
* Seed node should call this routine and pass its post operations list as parameter.
* @param ops List of fused post operations
*/
virtual void appendPostOps(dnnl::post_ops& ops, const VectorDims& postOpDims, std::vector<MemoryPtr>& postOpsMem);
virtual void appendPostOps(dnnl::post_ops& ops, const VectorDims& postOpDims, std::vector<const void*>& postOpsMem);
virtual void appendPostOps(dnnl::post_ops& ops, const VectorDims& postOpDims, std::vector<MemoryPtr>& postOpsMem, const int channelAxis = 1);
virtual void appendPostOps(dnnl::post_ops& ops, const VectorDims& postOpDims, std::vector<const void*>& postOpsMem, const int channelAxis = 1);

virtual void appendBinPostOps(dnnl::post_ops& ops, const VectorDims& postOpDims, std::vector<MemoryPtr>& binaryPostOpsMem);

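The base declarations in node.h give the new `channelAxis` parameter a default of 1, and the overrides in eltwise.h and fake_quantize.h below repeat the same default. In C++ the default argument is taken from the static type of the expression the call goes through, while the body is still dispatched virtually, so keeping the values identical avoids surprises. A small illustration (hypothetical classes, not the plugin's):

```cpp
#include <iostream>

struct Base {
    virtual void appendPostOps(int channelAxis = 1) {
        std::cout << "Base, axis " << channelAxis << '\n';
    }
    virtual ~Base() = default;
};

struct Derived : Base {
    // Repeating the same default keeps behaviour identical whether the call is made
    // through a Base reference or a Derived one: the default comes from the static
    // type, the body from the dynamic one.
    void appendPostOps(int channelAxis = 1) override {
        std::cout << "Derived, axis " << channelAxis << '\n';
    }
};

int main() {
    Derived d;
    Base& b = d;
    b.appendPostOps();  // prints "Derived, axis 1": default from Base, body from Derived
    d.appendPostOps();  // prints "Derived, axis 1"
}
```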
34 changes: 19 additions & 15 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
@@ -2107,7 +2107,7 @@ void Eltwise::appendMemory(const std::vector<float> &data, MemoryPtr &memPtr, st
}

template <typename T>
void Eltwise::appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<T>& postOpsMem) {
void Eltwise::appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<T>& postOpsMem, const int channelAxis) {
const std::string errorPrefix = "Appending Eltwise node with name '" + getName() + "' ";

if (getOneDnnAlgorithm() != dnnl::algorithm::undef) {
@@ -2137,40 +2137,44 @@ void Eltwise::appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDim
default: IE_THROW() << errorPrefix << "as post operation is not supported";
}
} else {
const size_t chIdx = postOpDims.size() > 1 ? getFusingAxis() : 0;
int channelSize = 1;
if (channelAxis >= 0) {
const auto chIdx = postOpDims.size() > 1 ? channelAxis : 0;
channelSize = postOpDims[chIdx];
}
// since legacy depthwise post ops mechanism requires broadcasted data we need to reinitialize it in case of changed shape
if (depthwiseData.empty() || depthwiseDataSize != 2 * postOpDims[chIdx]) {
if (depthwiseData.empty() || depthwiseDataSize != 2 * channelSize) {
depthwiseData.clear();
depthwiseMemory.reset();

depthwiseData.insert(depthwiseData.end(), scales.begin(), scales.end());
if (scales.size() == 1) {
depthwiseData.resize(postOpDims[chIdx], depthwiseData.back());
} else if (scales.size() != postOpDims[chIdx]) {
depthwiseData.resize(channelSize, depthwiseData.back());
} else if (scales.size() != channelSize) {
IE_THROW() << errorPrefix << "failed due to scales data size inconsistency";
}
depthwiseData.insert(depthwiseData.end(), shifts.begin(), shifts.end());
if (shifts.empty()) {
// in case of Prelu algorithm scales data is always empty
depthwiseData.resize(2 * postOpDims[chIdx], 0);
depthwiseData.resize(2 * channelSize, 0);
} else if (shifts.size() == 1) {
depthwiseData.resize(2 * postOpDims[chIdx], depthwiseData.back());
} else if (shifts.size() != postOpDims[chIdx]) {
depthwiseData.resize(2 * channelSize, depthwiseData.back());
} else if (shifts.size() != channelSize) {
IE_THROW() << errorPrefix << "failed due to shifts data size inconsistency";
}
depthwiseDataSize = 2 * postOpDims[chIdx];
depthwiseDataSize = 2 * channelSize;

// always align for legacy scale/shift post ops
constexpr int bufferAlignment = 16;
int bufferPaddingSize = rnd_up(postOpDims[chIdx], bufferAlignment) - postOpDims[chIdx];
int bufferPaddingSize = rnd_up(channelSize, bufferAlignment) - channelSize;
depthwiseData.resize(depthwiseDataSize + bufferPaddingSize, 0);
}

if (depthwiseData.empty())
IE_THROW() << errorPrefix << "cannot be performed since buffers are not allocated";

std::array<size_t, 2> offsets = {0};
offsets[1] = offsets[0] + postOpDims[chIdx];
offsets[1] = offsets[0] + channelSize;

/* @todo legacy depthwise post ops are kept for now
* for performance reasons
@@ -2195,12 +2199,12 @@ void Eltwise::appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDim
}
}

void Eltwise::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<MemoryPtr>& postOpsMem) {
appendPostOpsImpl(ops, postOpDims, postOpsMem);
void Eltwise::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<MemoryPtr>& postOpsMem, const int channelAxis) {
appendPostOpsImpl(ops, postOpDims, postOpsMem, channelAxis);
}

void Eltwise::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<const void*>& postOpsMem) {
appendPostOpsImpl(ops, postOpDims, postOpsMem);
void Eltwise::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<const void*>& postOpsMem, const int channelAxis) {
appendPostOpsImpl(ops, postOpDims, postOpsMem, channelAxis);
}

void Eltwise::appendBinPostOps(dnnl::post_ops& ops, const VectorDims& postOpDims, std::vector<MemoryPtr>& binaryPostOpsMem) {
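The core of the Eltwise change above is that the broadcast length for the legacy depthwise scale/shift buffers now comes from the caller-supplied `channelAxis` rather than the node's own `getFusingAxis()`, and collapses to a single value when that axis is reported as gone. A standalone sketch of that rule (hypothetical helper name):

```cpp
#include <cstddef>
#include <vector>

using Dims = std::vector<std::size_t>;

// Number of per-channel scale/shift entries to allocate for the post op.
std::size_t effectiveChannelSize(const Dims& postOpDims, int channelAxis) {
    if (channelAxis < 0)
        return 1;  // channel axis was reduced away: per-tensor broadcast only
    const std::size_t chIdx = postOpDims.size() > 1 ? static_cast<std::size_t>(channelAxis) : 0;
    return postOpDims[chIdx];
}

// effectiveChannelSize({8, 16, 32, 32}, 1)  -> 16   (normal per-channel case)
// effectiveChannelSize({8, 32, 32},    -1)  -> 1    (channel dim removed by a Reduce)
// effectiveChannelSize({128},           1)  -> 128  (1D output falls back to axis 0)
```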
6 changes: 3 additions & 3 deletions src/plugins/intel_cpu/src/nodes/eltwise.h
@@ -99,8 +99,8 @@ class Eltwise : public Node {
bool created() const override;
bool canBeInPlace() const override;
bool canFuse(const NodePtr& node) const override;
void appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<MemoryPtr>& postOpsMem) override;
void appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<const void*>& postOpsMem) override;
void appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<MemoryPtr>& postOpsMem, const int channelAxis = 1) override;
void appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<const void*>& postOpsMem, const int channelAxis = 1) override;
void appendBinPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<MemoryPtr>& binaryPostOpsMem) override;
void fuseInto(NodePtr& parentNode) override;
InferenceEngine::Precision getRuntimePrecision() const override;
@@ -172,7 +172,7 @@ class Eltwise : public Node {
size_t getOpInputsNum() const;

template <typename T>
void appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<T>& postOpsMem);
void appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<T>& postOpsMem, const int channelAxis = 1);

void appendMemory(const std::vector<float> &data, MemoryPtr &memPtr, std::vector<MemoryPtr>& postOpsMem);
void appendMemory(const std::vector<float> &data, MemoryPtr &memPtr, std::vector<const void*>& postOpsMem);
6 changes: 4 additions & 2 deletions src/plugins/intel_cpu/src/nodes/fake_quantize.cpp
@@ -1855,11 +1855,13 @@ void FakeQuantize::appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &post
}
}

void FakeQuantize::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<MemoryPtr>& postOpsMem) {
void FakeQuantize::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<MemoryPtr>& postOpsMem,
const int channelAxis) {
appendPostOpsImpl(ops, postOpDims, postOpsMem);
}

void FakeQuantize::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<const void*>& postOpsMem) {
void FakeQuantize::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<const void*>& postOpsMem,
const int channelAxis) {
appendPostOpsImpl(ops, postOpDims, postOpsMem);
}

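In the FakeQuantize overrides the new `channelAxis` parameter appears to be accepted only to match the updated virtual signature: it is not forwarded, and `appendPostOpsImpl` is still called without it, so the layout of the quantization post-op data is unchanged by this commit.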
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/nodes/fake_quantize.h
@@ -122,8 +122,8 @@ class FakeQuantize : public Node {
InferenceEngine::Precision getInputPrecision() const { return inputPrecision; }
InferenceEngine::Precision getOutputPrecision() const { return outputPrecision; }

void appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<MemoryPtr>& postOpsMem) override;
void appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<const void*>& postOpsMem) override;
void appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<MemoryPtr>& postOpsMem, const int channelAxis = 1) override;
void appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<const void*>& postOpsMem, const int channelAxis = 1) override;
void appendBinPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<MemoryPtr>& binaryPostOpsMem) override;

static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
@@ -386,7 +386,7 @@ void FullyConnected::setPostOps(dnnl::primitive_attr &attr, const VectorDims &di
auto getBinPostOpShape = [&](){
const size_t binaryShapeRank = getOutputShapeAtPort(0).getRank() == 3 ? 2 : getOutputShapeAtPort(0).getRank();
VectorDims binaryShape(binaryShapeRank, 1);
const size_t channelAxis = getFusingAxis();
const auto channelAxis = getFusingAxis();
// always use 1 as channelAxis for binary Shape, since oneDNN primitive is actually always 2D
binaryShape[1] = dims[channelAxis];

2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/fullyconnected.h
@@ -27,7 +27,7 @@ class FullyConnected : public Node {
return false;
}

size_t getFusingAxis() const override {
int getFusingAxis() const override {
return getOutputShapeAtPort(0).getRank() == 3 ? 2 : 1;
}

2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/matmul.h
@@ -33,7 +33,7 @@ class MatMul : public Node {
return getOriginalInputsNumber();
}

size_t getFusingAxis() const override {
int getFusingAxis() const override {
return getOutputShapeAtPort(0).getRank() - 1;
}

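Only the return types change in `FullyConnected::getFusingAxis` and `MatMul::getFusingAxis`; the rules themselves stay as before: a 3D FullyConnected output fuses on axis 2, otherwise axis 1, while MatMul always fuses on the last dimension. A quick comparison of the two rules (assumed output ranks, not plugin code):

```cpp
#include <cassert>
#include <cstddef>

int fcFusingAxis(std::size_t outputRank)     { return outputRank == 3 ? 2 : 1; }
int matmulFusingAxis(std::size_t outputRank) { return static_cast<int>(outputRank) - 1; }

int main() {
    assert(fcFusingAxis(3) == 2 && matmulFusingAxis(3) == 2);  // {B, M, N}: same axis
    assert(fcFusingAxis(2) == 1 && matmulFusingAxis(2) == 1);  // {M, N}:    same axis
    assert(fcFusingAxis(4) == 1 && matmulFusingAxis(4) == 3);  // 4D: the rules diverge
}
```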
19 changes: 18 additions & 1 deletion src/plugins/intel_cpu/src/nodes/reduce.cpp
@@ -2844,7 +2844,7 @@ void Reduce::setPostOps(dnnl::primitive_attr &attr, const VectorDims &postOpDims

auto* eltwiseNode = dynamic_cast<Eltwise *>(node.get());
if (eltwiseNode) {
eltwiseNode->appendPostOps(ops, postOpDims, postOpsDataPtrs);
eltwiseNode->appendPostOps(ops, postOpDims, postOpsDataPtrs, getFusingAxis());
continue;
}
IE_THROW() << "Fusing of " << NameFromType(node->getType()) << " operation to " << NameFromType(this->getType()) << " node is not implemented";
@@ -2921,6 +2921,23 @@ bool Reduce::canApplyJIT(const Precision &input_prec, const Precision &output_pr
std::find(std::begin(supportedPrecisions), std::end(supportedPrecisions), output_prec) != std::end(supportedPrecisions);
}

int Reduce::getFusingAxis() const {
int channelAxis = 1;
if (!keep_dims) {
for (auto &raw_axis : raw_axes) {
int axis = raw_axis >= 0 ? raw_axis : raw_axis + static_cast<int>(getInputShapeAtPort(REDUCE_DATA).getRank());
if (axis == 1) {
// channel axis has been reduced and doesn't exist any more
channelAxis = -1;
break;
} else if (axis == 0) {
channelAxis = 0;
}
}
}
return channelAxis;
}

bool Reduce::canFuse(const NodePtr& node) const {
Precision input_prec = getOriginalInputPrecisionAtPort(REDUCE_DATA);
Precision output_prec = getOriginalOutputPrecisionAtPort(0);
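`Reduce::getFusingAxis` applies the same rule as the `getChannelAxis` helper earlier in the commit, with one extra step: raw axes may be negative, so they are wrapped by the input rank first. A standalone sketch of that normalization (hypothetical free function):

```cpp
#include <vector>

int reduceFusingAxis(const std::vector<int>& rawAxes, bool keepDims, int inputRank) {
    int channelAxis = 1;
    if (!keepDims) {
        for (int rawAxis : rawAxes) {
            const int axis = rawAxis >= 0 ? rawAxis : rawAxis + inputRank;  // wrap negatives
            if (axis == 1) { channelAxis = -1; break; }  // channel dim was reduced away
            if (axis == 0) channelAxis = 0;
        }
    }
    return channelAxis;
}

// reduceFusingAxis({-3}, false, 4) -> -1   (axis -3 of a 4D tensor is axis 1)
// reduceFusingAxis({0},  false, 4) ->  0
// reduceFusingAxis({2},  true,  4) ->  1   (keep_dims leaves the channel axis intact)
```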
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/nodes/reduce.h
@@ -95,6 +95,7 @@ class Reduce : public Node {
void execute(dnnl::stream strm) override;
std::vector<VectorDims> shapeInfer() const override;
void executeDynamicImpl(dnnl::stream strm) override;
int getFusingAxis() const override;
bool canFuse(const NodePtr& node) const override;
bool canBeInPlace() const override {
return false;