Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Merge branch 'develop' into add_new_model
Browse files Browse the repository at this point in the history
  • Loading branch information
wenming2014 authored Sep 27, 2021
2 parents d1f6e55 + bf94356 commit 87833c2
Show file tree
Hide file tree
Showing 29 changed files with 1,284 additions and 191 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ gen_modules
docs/source/cpp
docs/source/doxygen_output
docs/source/tutorials
.vscode*
3 changes: 3 additions & 0 deletions cinn/frontend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ core_gather_srcs(SRCS
syntax.cc
paddle_model_to_program.cc
interpreter.cc
base_builder.cc
net_builder.cc
)

if(NOT WITH_CUDA)
Expand All @@ -23,5 +25,6 @@ else()
SRCS interpreter_test.cc DEPS cinncore)
endif()

cc_test(test_net_builder SRCS net_builder_test.cc DEPS cinncore)

add_subdirectory(paddle)
34 changes: 34 additions & 0 deletions cinn/frontend/base_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include "cinn/frontend/base_builder.h"

#include <string>
#include <utility>

#include "cinn/common/common.h"
#include "cinn/common/context.h"

namespace cinn {
namespace frontend {

Program BaseBuilder::Build() {
Program program{std::move(instrs_), std::move(inputs_)};
program.Validate();
return program;
}

Placeholder BaseBuilder::CreateInput(const common::Type& type,
const std::vector<int>& shape,
const std::string& id_hint) {
if (!id_hint.empty()) {
CheckVarNameValid(id_hint);
}
std::string id = id_hint.empty() ? common::Context::Global().NewName("placeholder") : id_hint;

inputs_.emplace_back(id);
auto& var = inputs_.back();
var->type = type;
var->shape = shape;
return Placeholder(var);
}

} // namespace frontend
} // namespace cinn
35 changes: 35 additions & 0 deletions cinn/frontend/base_builder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include <string>
#include <utility>
#include <vector>

#include "cinn/common/type.h"
#include "cinn/frontend/syntax.h"

namespace cinn {
namespace frontend {

class BaseBuilder {
public:
explicit BaseBuilder(const std::string& name) : name_(name) {}

Program Build();

Placeholder CreateInput(const common::Type& type, const std::vector<int>& shape, const std::string& id_hint = "");

// name of this builder
const std::string& name() { return name_; }

virtual ~BaseBuilder() {}

protected:
void AppendInstruction(const Instruction& instr) { instrs_.push_back(instr); }

std::string name_;
std::vector<Instruction> instrs_;
std::vector<Variable> inputs_;
};

} // namespace frontend
} // namespace cinn
194 changes: 194 additions & 0 deletions cinn/frontend/net_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
#include "cinn/frontend/net_builder.h"

#include <string>
#include <unordered_map>
#include <utility>

#include "cinn/frontend/syntax.h"

namespace cinn {
namespace frontend {

Variable NetBuilder::add(const Variable& a, const Variable& b) {
Instruction instr("elementwise_add", {a, b});
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::mul(const Variable& a, const Variable& b, int x_num_col_dims, int y_num_col_dims) {
Instruction instr("mul", {a, b});
instr.SetAttr("x_num_col_dims", x_num_col_dims);
instr.SetAttr("y_num_col_dims", y_num_col_dims);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::mulbias(
const Variable& a, const Variable& b, const Variable& c, int x_num_col_dims, int y_num_col_dims) {
Instruction instr("mulbias", {a, b, c});
instr.SetAttr("x_num_col_dims", x_num_col_dims);
instr.SetAttr("y_num_col_dims", y_num_col_dims);
AppendInstruction(instr);
return instr.GetOutput(1);
}

Variable NetBuilder::elementwise_add(const Variable& a, const Variable& b, int axis) {
Instruction instr("elementwise_add", {a, b});
instr.SetAttr("axis", axis);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::elementwise_mul(const Variable& a, const Variable& b, int axis) {
Instruction instr("elementwise_mul", {a, b});
instr.SetAttr("axis", axis);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::relu(const Variable& a) {
Instruction instr("relu", {a});
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::relu6(const Variable& a, float threshold) {
Instruction instr("relu6", {a});
instr.SetAttr("threshold", threshold);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::conv2d(const Variable& a,
const Variable& b,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
const std::string& padding_algorithm) {
Instruction instr("conv2d");
instr.SetInputs({a, b});
instr.SetAttr("strides", strides);
instr.SetAttr("paddings", paddings);
instr.SetAttr("dilations", dilations);
instr.SetAttr("groups", groups);
instr.SetAttr("data_format", data_format);
instr.SetAttr("padding_algorithm", padding_algorithm);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::depthwise_conv2d(const Variable& a,
const Variable& b,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
const std::string& padding_algorithm) {
Instruction instr("depthwise_conv2d");
instr.SetInputs({a, b});
instr.SetAttr("strides", strides);
instr.SetAttr("paddings", paddings);
instr.SetAttr("dilations", dilations);
instr.SetAttr("groups", groups);
instr.SetAttr("data_format", data_format);
instr.SetAttr("padding_algorithm", padding_algorithm);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::pool2d(const Variable& a,
const std::string& pooling_type,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
bool ceil_mode,
bool exclusive,
bool global_pooling,
const std::string& data_format,
bool adaptive,
const std::string& padding_algorithm) {
Instruction instr("pool2d");
instr.SetInputs({a});
instr.SetAttr("pooling_type", pooling_type);
instr.SetAttr("ksize", ksize);
instr.SetAttr("strides", strides);
instr.SetAttr("paddings", paddings);
instr.SetAttr("ceil_mode", ceil_mode);
instr.SetAttr("exclusive", exclusive);
instr.SetAttr("global_pooling", global_pooling);
instr.SetAttr("data_format", data_format);
instr.SetAttr("adaptive", adaptive);
instr.SetAttr("padding_algorithm", padding_algorithm);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::batchnorm(const Variable& a,
const Variable& scale,
const Variable& bias,
const Variable& mean,
const Variable& variance,
float epsilon,
float momentum,
const std::string& data_layout) {
Instruction instr("batchnorm");
instr.SetInputs({a, scale, bias, mean, variance});
instr.SetAttr("epsilon", epsilon);
instr.SetAttr("momentum", momentum);
instr.SetAttr("data_layout", data_layout);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::scale(const Variable& a, float scale, float bias, bool bias_after_scale) {
Instruction instr("scale", {a});
instr.SetAttr("scale", scale);
instr.SetAttr("bias", bias);
instr.SetAttr("bias_after_scale", bias_after_scale);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::softmax(const Variable& a, int axis, const std::string& data_format) {
Instruction instr("softmax", {a});
instr.SetAttr("axis", axis);
instr.SetAttr("data_format", data_format);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::sigmoid(const Variable& a) {
Instruction instr("sigmoid", {a});
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::slice(const Variable& a,
const std::vector<int>& axes,
const std::vector<int>& starts,
const std::vector<int>& ends,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis) {
Instruction instr("slice", {a});
instr.SetAttr("axes", axes);
instr.SetAttr("starts", starts);
instr.SetAttr("ends", ends);
instr.SetAttr("infer_flags", infer_flags);
instr.SetAttr("decrease_axis", decrease_axis);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::dropout_infer(const Variable& a, float dropout_prob, const std::string& dropout_implementation) {
Instruction instr("dropout_infer", {a});
instr.SetAttr("dropout_prob", dropout_prob);
instr.SetAttr("dropout_implementation", dropout_implementation);
AppendInstruction(instr);
return instr.GetOutput(0);
}

} // namespace frontend
} // namespace cinn
Loading

0 comments on commit 87833c2

Please sign in to comment.