This repository has been archived by the owner on Jan 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 114
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'develop' into add_new_model
- Loading branch information
Showing
29 changed files
with
1,284 additions
and
191 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,3 +19,4 @@ gen_modules | |
docs/source/cpp | ||
docs/source/doxygen_output | ||
docs/source/tutorials | ||
.vscode* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#include "cinn/frontend/base_builder.h" | ||
|
||
#include <string> | ||
#include <utility> | ||
|
||
#include "cinn/common/common.h" | ||
#include "cinn/common/context.h" | ||
|
||
namespace cinn { | ||
namespace frontend { | ||
|
||
Program BaseBuilder::Build() { | ||
Program program{std::move(instrs_), std::move(inputs_)}; | ||
program.Validate(); | ||
return program; | ||
} | ||
|
||
Placeholder BaseBuilder::CreateInput(const common::Type& type, | ||
const std::vector<int>& shape, | ||
const std::string& id_hint) { | ||
if (!id_hint.empty()) { | ||
CheckVarNameValid(id_hint); | ||
} | ||
std::string id = id_hint.empty() ? common::Context::Global().NewName("placeholder") : id_hint; | ||
|
||
inputs_.emplace_back(id); | ||
auto& var = inputs_.back(); | ||
var->type = type; | ||
var->shape = shape; | ||
return Placeholder(var); | ||
} | ||
|
||
} // namespace frontend | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#pragma once | ||
|
||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "cinn/common/type.h" | ||
#include "cinn/frontend/syntax.h" | ||
|
||
namespace cinn { | ||
namespace frontend { | ||
|
||
class BaseBuilder { | ||
public: | ||
explicit BaseBuilder(const std::string& name) : name_(name) {} | ||
|
||
Program Build(); | ||
|
||
Placeholder CreateInput(const common::Type& type, const std::vector<int>& shape, const std::string& id_hint = ""); | ||
|
||
// name of this builder | ||
const std::string& name() { return name_; } | ||
|
||
virtual ~BaseBuilder() {} | ||
|
||
protected: | ||
void AppendInstruction(const Instruction& instr) { instrs_.push_back(instr); } | ||
|
||
std::string name_; | ||
std::vector<Instruction> instrs_; | ||
std::vector<Variable> inputs_; | ||
}; | ||
|
||
} // namespace frontend | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
#include "cinn/frontend/net_builder.h" | ||
|
||
#include <string> | ||
#include <unordered_map> | ||
#include <utility> | ||
|
||
#include "cinn/frontend/syntax.h" | ||
|
||
namespace cinn { | ||
namespace frontend { | ||
|
||
Variable NetBuilder::add(const Variable& a, const Variable& b) { | ||
Instruction instr("elementwise_add", {a, b}); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::mul(const Variable& a, const Variable& b, int x_num_col_dims, int y_num_col_dims) { | ||
Instruction instr("mul", {a, b}); | ||
instr.SetAttr("x_num_col_dims", x_num_col_dims); | ||
instr.SetAttr("y_num_col_dims", y_num_col_dims); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::mulbias( | ||
const Variable& a, const Variable& b, const Variable& c, int x_num_col_dims, int y_num_col_dims) { | ||
Instruction instr("mulbias", {a, b, c}); | ||
instr.SetAttr("x_num_col_dims", x_num_col_dims); | ||
instr.SetAttr("y_num_col_dims", y_num_col_dims); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(1); | ||
} | ||
|
||
Variable NetBuilder::elementwise_add(const Variable& a, const Variable& b, int axis) { | ||
Instruction instr("elementwise_add", {a, b}); | ||
instr.SetAttr("axis", axis); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::elementwise_mul(const Variable& a, const Variable& b, int axis) { | ||
Instruction instr("elementwise_mul", {a, b}); | ||
instr.SetAttr("axis", axis); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::relu(const Variable& a) { | ||
Instruction instr("relu", {a}); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::relu6(const Variable& a, float threshold) { | ||
Instruction instr("relu6", {a}); | ||
instr.SetAttr("threshold", threshold); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::conv2d(const Variable& a, | ||
const Variable& b, | ||
const std::vector<int>& strides, | ||
const std::vector<int>& paddings, | ||
const std::vector<int>& dilations, | ||
int groups, | ||
const std::string& data_format, | ||
const std::string& padding_algorithm) { | ||
Instruction instr("conv2d"); | ||
instr.SetInputs({a, b}); | ||
instr.SetAttr("strides", strides); | ||
instr.SetAttr("paddings", paddings); | ||
instr.SetAttr("dilations", dilations); | ||
instr.SetAttr("groups", groups); | ||
instr.SetAttr("data_format", data_format); | ||
instr.SetAttr("padding_algorithm", padding_algorithm); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::depthwise_conv2d(const Variable& a, | ||
const Variable& b, | ||
const std::vector<int>& strides, | ||
const std::vector<int>& paddings, | ||
const std::vector<int>& dilations, | ||
int groups, | ||
const std::string& data_format, | ||
const std::string& padding_algorithm) { | ||
Instruction instr("depthwise_conv2d"); | ||
instr.SetInputs({a, b}); | ||
instr.SetAttr("strides", strides); | ||
instr.SetAttr("paddings", paddings); | ||
instr.SetAttr("dilations", dilations); | ||
instr.SetAttr("groups", groups); | ||
instr.SetAttr("data_format", data_format); | ||
instr.SetAttr("padding_algorithm", padding_algorithm); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::pool2d(const Variable& a, | ||
const std::string& pooling_type, | ||
const std::vector<int>& ksize, | ||
const std::vector<int>& strides, | ||
const std::vector<int>& paddings, | ||
bool ceil_mode, | ||
bool exclusive, | ||
bool global_pooling, | ||
const std::string& data_format, | ||
bool adaptive, | ||
const std::string& padding_algorithm) { | ||
Instruction instr("pool2d"); | ||
instr.SetInputs({a}); | ||
instr.SetAttr("pooling_type", pooling_type); | ||
instr.SetAttr("ksize", ksize); | ||
instr.SetAttr("strides", strides); | ||
instr.SetAttr("paddings", paddings); | ||
instr.SetAttr("ceil_mode", ceil_mode); | ||
instr.SetAttr("exclusive", exclusive); | ||
instr.SetAttr("global_pooling", global_pooling); | ||
instr.SetAttr("data_format", data_format); | ||
instr.SetAttr("adaptive", adaptive); | ||
instr.SetAttr("padding_algorithm", padding_algorithm); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::batchnorm(const Variable& a, | ||
const Variable& scale, | ||
const Variable& bias, | ||
const Variable& mean, | ||
const Variable& variance, | ||
float epsilon, | ||
float momentum, | ||
const std::string& data_layout) { | ||
Instruction instr("batchnorm"); | ||
instr.SetInputs({a, scale, bias, mean, variance}); | ||
instr.SetAttr("epsilon", epsilon); | ||
instr.SetAttr("momentum", momentum); | ||
instr.SetAttr("data_layout", data_layout); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::scale(const Variable& a, float scale, float bias, bool bias_after_scale) { | ||
Instruction instr("scale", {a}); | ||
instr.SetAttr("scale", scale); | ||
instr.SetAttr("bias", bias); | ||
instr.SetAttr("bias_after_scale", bias_after_scale); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::softmax(const Variable& a, int axis, const std::string& data_format) { | ||
Instruction instr("softmax", {a}); | ||
instr.SetAttr("axis", axis); | ||
instr.SetAttr("data_format", data_format); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::sigmoid(const Variable& a) { | ||
Instruction instr("sigmoid", {a}); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::slice(const Variable& a, | ||
const std::vector<int>& axes, | ||
const std::vector<int>& starts, | ||
const std::vector<int>& ends, | ||
const std::vector<int>& infer_flags, | ||
const std::vector<int>& decrease_axis) { | ||
Instruction instr("slice", {a}); | ||
instr.SetAttr("axes", axes); | ||
instr.SetAttr("starts", starts); | ||
instr.SetAttr("ends", ends); | ||
instr.SetAttr("infer_flags", infer_flags); | ||
instr.SetAttr("decrease_axis", decrease_axis); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
Variable NetBuilder::dropout_infer(const Variable& a, float dropout_prob, const std::string& dropout_implementation) { | ||
Instruction instr("dropout_infer", {a}); | ||
instr.SetAttr("dropout_prob", dropout_prob); | ||
instr.SetAttr("dropout_implementation", dropout_implementation); | ||
AppendInstruction(instr); | ||
return instr.GetOutput(0); | ||
} | ||
|
||
} // namespace frontend | ||
} // namespace cinn |
Oops, something went wrong.