From 3c46e58a5d6d9a61994b4f3410cf97118ede1c7d Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 29 Mar 2021 21:20:37 -0700 Subject: [PATCH 1/3] Set up CNN accelerator codegen skeleton --- .gitignore | 2 + CMakeLists.txt | 1 + cmake/modules/contrib/ILACNN.cmake | 9 ++ python/tvm/relay/op/contrib/ilacnn.py | 63 ++++++++ .../backend/contrib/ilacnn/ilacnn_codegen.cc | 98 +++++++++++ src/runtime/contrib/ilacnn/ilacnn_runtime.cc | 152 ++++++++++++++++++ .../python/byo3la/end_to_end_efficientnet.py | 57 +++++++ tests/python/byo3la/match_conv2d.py | 58 +++++++ 8 files changed, 440 insertions(+) create mode 100644 cmake/modules/contrib/ILACNN.cmake create mode 100644 python/tvm/relay/op/contrib/ilacnn.py create mode 100644 src/relay/backend/contrib/ilacnn/ilacnn_codegen.cc create mode 100644 src/runtime/contrib/ilacnn/ilacnn_runtime.cc create mode 100644 tests/python/byo3la/end_to_end_efficientnet.py create mode 100644 tests/python/byo3la/match_conv2d.py diff --git a/.gitignore b/.gitignore index cdcf6780a..7186387dc 100644 --- a/.gitignore +++ b/.gitignore @@ -233,3 +233,5 @@ conda/pkg # nix files .envrc *.nix + +tests/python/byo3la/EfficientNet/* \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 386ca5d43..e52d2dc2a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -341,6 +341,7 @@ include(cmake/modules/contrib/BLAS.cmake) include(cmake/modules/contrib/CODEGENC.cmake) include(cmake/modules/contrib/DNNL.cmake) include(cmake/modules/contrib/ILAVTA.cmake) +include(cmake/modules/contrib/ILACNN.cmake) include(cmake/modules/contrib/Random.cmake) include(cmake/modules/contrib/Posit.cmake) include(cmake/modules/contrib/MicroStandaloneRuntime.cmake) diff --git a/cmake/modules/contrib/ILACNN.cmake b/cmake/modules/contrib/ILACNN.cmake new file mode 100644 index 000000000..725acd306 --- /dev/null +++ b/cmake/modules/contrib/ILACNN.cmake @@ -0,0 +1,9 @@ +if(USE_ILACNN_CODEGEN STREQUAL "ON") + add_definitions(-DUSE_ILACNN_RUNTIME=1) + file(GLOB ILACNN_RELAY_CONTRIB_SRC src/relay/backend/contrib/ilacnn/*.cc) + list(APPEND COMPILER_SRCS ${ILACNN_RELAY_CONTRIB_SRC}) + list(APPEND COMPILER_SRCS ${JSON_RELAY_CONTRIB_SRC}) + + file(GLOB ILACNN_CONTRIB_SRC src/runtime/contrib/ilacnn/ilacnn_runtime.cc) + list(APPEND RUNTIME_SRCS ${ILACNN_CONTRIB_SRC}) +endif() \ No newline at end of file diff --git a/python/tvm/relay/op/contrib/ilacnn.py b/python/tvm/relay/op/contrib/ilacnn.py new file mode 100644 index 000000000..6096d48e3 --- /dev/null +++ b/python/tvm/relay/op/contrib/ilacnn.py @@ -0,0 +1,63 @@ +""" +Python bindings and helpers for ILACNN codegen, +note that the accelerator does not do padding for Conv2D's, +so you should use remove_padding on the main function before pattern matching +(this converts conv2d's with padding to conv2d(pad(data))) +""" +import tvm +from tvm import relay +from tvm.relay.expr_functor import ExprMutator +import tvm.ir +from ...dataflow_pattern import wildcard, is_op +from .register import register_pattern_table + +def remove_padding(func): + """ + The CNN accelerator cannot handle padding in conv2d, + so this will rewrite all conv2d's with padding into + conv2d on a separately padded tensor (i.e., handle padding in the host) + """ + class PaddingRemover(ExprMutator): + def visit_call(self, call): + if call.attrs is None: + return super().visit_call(call) + attrs = call.attrs + if not isinstance(attrs, relay.op.op_attrs.Conv2DAttrs): + return super().visit_call(call) + padding = attrs.padding + # nothing to do if no padding + if all(map(lambda 
d: d == 0, padding)):
+                return super().visit_call(call)
+
+            # otherwise rewrite as a padded call
+            data = self.visit(call.args[0])
+            weight = self.visit(call.args[1])
+
+            # relay.nn.pad expects padding in the format of (x_left, x_right), (y_top, y_bottom)
+            data_layout = attrs.data_layout
+            # we are only padding the H and W dimensions
+            pad_dims = [(0, 0), (0, 0), (padding[0], padding[2]), (padding[1], padding[3])]
+            if data_layout == "NHWC":
+                pad_dims = [(0, 0), (padding[0], padding[2]), (padding[1], padding[3]), (0, 0)]
+
+            padded_data = relay.nn.pad(data, pad_dims)
+            return relay.nn.conv2d(padded_data, weight,
+                                   strides=attrs.strides,
+                                   padding=0,
+                                   dilation=attrs.dilation,
+                                   groups=attrs.groups,
+                                   channels=attrs.channels,
+                                   kernel_size=attrs.kernel_size,
+                                   data_layout=attrs.data_layout,
+                                   kernel_layout=attrs.kernel_layout,
+                                   out_layout=attrs.out_layout,
+                                   out_dtype=attrs.out_dtype)
+
+    remover = PaddingRemover()
+    return remover.visit(func)
+
+
+@register_pattern_table("ilacnn")
+def pattern_table():
+    conv2d_pattern = ("ilacnn.conv2d", is_op('nn.conv2d')(wildcard(), wildcard()))
+    return [conv2d_pattern]
diff --git a/src/relay/backend/contrib/ilacnn/ilacnn_codegen.cc b/src/relay/backend/contrib/ilacnn/ilacnn_codegen.cc
new file mode 100644
index 000000000..c79ff0d3c
--- /dev/null
+++ b/src/relay/backend/contrib/ilacnn/ilacnn_codegen.cc
@@ -0,0 +1,98 @@
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+
+#include <fstream>
+#include <numeric>
+#include <sstream>
+#include <string>
+
+#include "../../utils.h"
+
+#include "../../../../runtime/contrib/json/json_node.h"
+#include "../codegen_json/codegen_json.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+
+using namespace backend;
+
+class IlaCNNJSONSerializer : public backend::contrib::JSONSerializer {
+  using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+  using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+
+ public:
+  IlaCNNJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {}
+
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* cn) override {
+    std::string name;
+
+    if (const auto* op_node = cn->op.as<OpNode>()) {
+      name = op_node->name;
+    } else if (const auto* fn = cn->op.as<FunctionNode>()) {
+      auto comp = fn->GetAttr<String>(attr::kComposite);
+      CHECK(comp.defined())
+          << "JSON runtime only supports composite functions.";
+      name = comp.value();
+
+      if (name != "ilacnn.conv2d") {
+        LOG(FATAL) << "Unrecognized pattern: " << name;
+      }
+    } else {
+      LOG(FATAL) << "IlaCNN runtime does not support calls to "
+                 << cn->op->GetTypeKey();
+    }
+    LOG(INFO) << "[Pattern Matching] Found annotated pattern: " << name;
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    for (const auto& arg : cn->args) {
+      auto res = VisitExpr(arg);
+      inputs.insert(inputs.end(), res.begin(), res.end());
+    }
+    auto node = std::make_shared<JSONGraphNode>(name,     /* name_ */
+                                                "kernel", /* op_type_ */
+                                                inputs, 1 /* num_outputs_ */);
+
+    // Note: conv2d has a lot of attrs that are relevant for codegen,
+    // especially the stride size.
+    // However, the pattern matcher will produce patterns in the form of
+    //   fn(Compiler="ilacnn") {
+    //     fn(Composite="ilacnn.conv2d") { nn.conv2d(...) }
+    //   }
+    // so we need to reach inside the inner function to get the conv2d attrs (weird, yeah);
+    // see codegen_json.h:SetCallNodeAttribute
+
+    tvm::relay::backend::contrib::OpAttrExtractor extractor(node);
+    auto inner_func = Downcast<Function>(cn->op);
+    auto inner_call = Downcast<Call>(inner_func->body);
+    const Object* inner_call_attr = inner_call->attrs.get();
+    extractor.Extract(const_cast<Object*>(inner_call_attr));
+    return AddNode(node, GetRef<Expr>(cn));
+  }
+};  // class IlaCNNJSONSerializer
+
+runtime::Module IlaCNNCompiler(const ObjectRef& ref) {
+  CHECK(ref->IsInstance<FunctionNode>());
+  auto func = Downcast<Function>(ref);
+  auto func_name = GetExtSymbol(func);
+
+  IlaCNNJSONSerializer serializer(func_name, func);
+  serializer.serialize();
+  std::string graph_json = serializer.GetJSON();
+  auto params = serializer.GetParams();
+
+  const auto* pf = runtime::Registry::Get("runtime.IlaCNNRuntimeCreate");
+  CHECK(pf != nullptr) << "Cannot find IlaCNN runtime module to create";
+  auto mod = (*pf)(func_name, graph_json, params);
+  return mod;
+}
+
+TVM_REGISTER_GLOBAL("relay.ext.ilacnn").set_body_typed(IlaCNNCompiler);
+
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/runtime/contrib/ilacnn/ilacnn_runtime.cc b/src/runtime/contrib/ilacnn/ilacnn_runtime.cc
new file mode 100644
index 000000000..1f1ccd6c7
--- /dev/null
+++ b/src/runtime/contrib/ilacnn/ilacnn_runtime.cc
@@ -0,0 +1,152 @@
+#include <dmlc/logging.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <fstream>
+#include <iostream>
+#include <numeric>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace tvm::runtime;
+using namespace tvm::runtime::json;
+
+class IlaCNNRuntime : public JSONRuntimeBase {
+ public:
+  IlaCNNRuntime(const std::string& symbol_name, const std::string& graph_json,
+                const Array<String> const_names)
+      : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+  const char* type_key() const { return "ilacnn"; }
+
+  void Init(const Array<NDArray>& consts) override {
+    // CHECK(consts.size() == 0);
+  }
+
+  void Run() override {
+    CHECK(symbol_name_.substr(0, 6) == "ilacnn") << symbol_name_;
+    LOG(INFO) << "[Runtime] entering " << symbol_name_ << " runtime";
+
+    if (outputs_.size() == 1 && input_nodes_.size() == 2 &&
+        nodes_[outputs_[0].id_].GetOpName() == "ilacnn.conv2d") {
+      auto call_node = nodes_[outputs_[0].id_];
+
+      // data
+      auto eid_data = EntryID(input_nodes_[0], 0);
+      auto& data_info = data_entry_[eid_data];
+      CHECK(data_info->ndim == 4);
+      std::cout << "Data shape: ("
+                << data_info->shape[0] << ", "
+                << data_info->shape[1] << ", "
+                << data_info->shape[2] << ", "
+                << data_info->shape[3]
+                << ")" << std::endl;
+
+      // weight
+      auto eid_weight = EntryID(input_nodes_[1], 0);
+      auto& weight_info = data_entry_[eid_weight];
+      CHECK(weight_info->ndim == 4);
+      std::cout << "Weight shape: ("
+                << weight_info->shape[0] << ", "
+                << weight_info->shape[1] << ", "
+                << weight_info->shape[2] << ", "
+                << weight_info->shape[3]
+                << ")" << std::endl;
+
+      // output
+      auto eid_o = outputs_[0].id_;
+      auto out_info = data_entry_[eid_o];
+      CHECK(out_info->ndim == 4);
+      std::cout << "Output shape: ("
+                << out_info->shape[0] << ", "
+                << out_info->shape[1] << ", "
+                << out_info->shape[2] << ", "
+                << out_info->shape[3]
+                << ")" << std::endl;
+
+      // attributes
+      auto strides = call_node.GetAttr<std::vector<std::string>>("strides");
+      auto padding = call_node.GetAttr<std::vector<std::string>>("padding");
+      auto data_layout = call_node.GetAttr<std::vector<std::string>>("data_layout");
+      auto kernel_layout = call_node.GetAttr<std::vector<std::string>>("kernel_layout");
+      // etc
+
+      std::cout << "Strides: " << "(";
+      for (const auto& dim : strides) {
+        std::cout << dim << ",";
+      }
+      std::cout << ")" << std::endl;
+      std::cout << "Padding: " << "(";
+      for (const auto& dim : padding) {
+        std::cout << dim << ",";
+      }
+      std::cout << ")" << std::endl;
+      std::cout << "Data layout: " << data_layout[0] << std::endl;
+      std::cout << "Kernel layout: " << kernel_layout[0] << std::endl;
+
+      // TODO: Instantiate and call driver
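+      // A possible wiring for the driver call, using the dump_data and
+      // retrieve_result helpers below. This is only a sketch: the driver
+      // command and file paths are hypothetical placeholders, since the
+      // actual ILA conv2d driver does not exist yet.
+      //
+      //   unsigned long data_size = GetDataSize(*data_info) / sizeof(float);
+      //   dump_data(static_cast<float*>(data_info->data), data_size,
+      //             "./data/conv2d_input.txt");
+      //   unsigned long weight_size = GetDataSize(*weight_info) / sizeof(float);
+      //   dump_data(static_cast<float*>(weight_info->data), weight_size,
+      //             "./data/conv2d_weight.txt");
+      //   std::system("python3 conv2d_driver.py");  // hypothetical driver invocation
+      //   unsigned long out_size = GetDataSize(*out_info) / sizeof(float);
+      //   retrieve_result(static_cast<float*>(out_info->data), out_size,
+      //                   "./data/conv2d_output.txt");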
+    } else {
+      LOG(FATAL) << "Unknown pattern " << symbol_name_;
+    }
+    LOG(INFO) << "[Runtime] exit " << symbol_name_ << " runtime, resume host";
+  }
+
+  void dump_data(float* data_ptr, unsigned long size, std::string path) {
+    std::ofstream fout;
+    std::stringstream ss;
+    fout.open(path, std::ios::out | std::ios::trunc);
+    for (unsigned long i = 0; i < size; ++i) {
+      ss << data_ptr[i] << '\n';
+    }
+    fout << ss.rdbuf();
+    fout.close();
+  }
+
+  void retrieve_result(float* data_ptr, unsigned long size, std::string path) {
+    // retrieve the accelerator's results from the result file
+    std::ifstream fin;
+    fin.open(path, std::ios::in);
+    std::string float_str;
+    unsigned long cntr = 0;
+
+    while (std::getline(fin, float_str)) {
+      if (cntr >= size) {
+        LOG(FATAL) << "wrong number of elements in the result tensor";
+      }
+      data_ptr[cntr] = std::stof(float_str);
+      ++cntr;
+    }
+  }
+};  // class IlaCNNRuntime
+
+runtime::Module IlaCNNRuntimeCreate(String symbol_name, String graph_json,
+                                    const Array<String>& const_names) {
+  auto n = make_object<IlaCNNRuntime>(symbol_name, graph_json, const_names);
+  return runtime::Module(n);
+}
+
+TVM_REGISTER_GLOBAL("runtime.IlaCNNRuntimeCreate")
+    .set_body_typed(IlaCNNRuntimeCreate);
+
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_ilacnn")
+    .set_body_typed(JSONRuntimeBase::LoadFromBinary<IlaCNNRuntime>);
+
+}  // namespace contrib
+}  // namespace runtime
+}  // namespace tvm
diff --git a/tests/python/byo3la/end_to_end_efficientnet.py b/tests/python/byo3la/end_to_end_efficientnet.py
new file mode 100644
index 000000000..4e8fcfa86
--- /dev/null
+++ b/tests/python/byo3la/end_to_end_efficientnet.py
@@ -0,0 +1,57 @@
+"""
+Clones an MXNet EfficientNet implementation, imports it into TVM,
+and runs it via the ILACNN codegen
+"""
+import os
+import subprocess
+
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm import runtime
+from tvm.relay.op.contrib import ilacnn
+
+TEST_DIR = os.path.dirname(os.path.abspath(__file__))
+ENET_DIR = os.path.join(TEST_DIR, "EfficientNet")
+PARAMS_FILE = os.path.join(ENET_DIR, "0.3358-imagenet-efficientnet-b0-47-best.params")
+
+def efficientnet_present():
+    return os.path.exists(ENET_DIR) and os.path.exists(PARAMS_FILE)
+
+
+def pull_efficientnet():
+    subprocess.run(["rm", "-rf", ENET_DIR])
+    subprocess.run(["git", "clone", "https://github.com/mnikitin/EfficientNet.git"], cwd=TEST_DIR)
+    subprocess.run(["wget", "https://www.dropbox.com/s/l2ehu85vmmj3w5w/0.3358-imagenet-efficientnet-b0-47-best.params"], cwd=ENET_DIR)
+
+
+def main():
+    if not efficientnet_present():
+        pull_efficientnet()
+    from EfficientNet.efficientnet_model import get_efficientnet
+
+    enet, _ = get_efficientnet("efficientnet-b0")
+    enet.load_parameters(PARAMS_FILE)
+    mod, params = relay.frontend.from_mxnet(enet, {"data": (1, 3, 224, 224)})
+
+    params["data"] = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"))
+    args = [params[var.name_hint] for var in mod["main"].params]
+    mod["main"] = ilacnn.remove_padding(mod["main"])
+
+    pattern_table = ilacnn.pattern_table()
+    mod = tvm.relay.transform.MergeComposite(pattern_table)(mod)
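+    # AnnotateTarget marks the merged composite functions for the "ilacnn"
+    # external codegen, and PartitionGraph then lifts them into separate
+    # functions that are compiled by relay.ext.ilacnn (ilacnn_codegen.cc)
+    # rather than by the host.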
+ mod = tvm.relay.transform.AnnotateTarget(["ilacnn"])(mod) + mod = tvm.relay.transform.PartitionGraph()(mod) + + with tvm.transform.PassContext(opt_level=3): + device = tvm.cpu() + target = "llvm" + exe = relay.vm.compile(mod, target) + vm = runtime.vm.VirtualMachine(exe, device) + + ret = vm.invoke("main", *args) + + +if __name__ == "__main__": + main() diff --git a/tests/python/byo3la/match_conv2d.py b/tests/python/byo3la/match_conv2d.py new file mode 100644 index 000000000..eefb1d094 --- /dev/null +++ b/tests/python/byo3la/match_conv2d.py @@ -0,0 +1,58 @@ +import numpy as np +import tvm +from tvm import relay +from tvm import runtime +from tvm.relay.op.contrib import ilacnn + +# just some simple smoke tests +def test_conv2d_unpadded(): + x = relay.Var("x", type_annotation=relay.TensorType((1, 3, 224, 224))) + y = relay.Var("y", type_annotation=relay.TensorType((3, 3, 3, 3))) + conv_func = relay.Function([x, y], relay.nn.conv2d(x, y)) + mod = tvm.IRModule() + mod["main"] = conv_func + + pattern_table = ilacnn.pattern_table() + mod = tvm.relay.transform.MergeComposite(pattern_table)(mod) + mod = tvm.relay.transform.AnnotateTarget(["ilacnn"])(mod) + mod = tvm.relay.transform.PartitionGraph()(mod) + print(mod) + + with tvm.transform.PassContext(opt_level=3): + device = tvm.cpu() + target = "llvm" + exe = relay.vm.compile(mod, target) + vm = runtime.vm.VirtualMachine(exe, device) + + args = [np.random.rand(1, 3, 224, 224).astype("float32"), + np.random.rand(3, 3, 3, 3).astype("float32")] + ret = vm.invoke("main", *args) + + +def test_conv2d_padded(): + x = relay.Var("x", type_annotation=relay.TensorType((1, 3, 220, 218))) + y = relay.Var("y", type_annotation=relay.TensorType((3, 3, 3, 3))) + conv_func = relay.Function([x, y], relay.nn.conv2d(x, y, padding=(2, 3))) + mod = tvm.IRModule() + mod["main"] = ilacnn.remove_padding(conv_func) + + pattern_table = ilacnn.pattern_table() + mod = tvm.relay.transform.MergeComposite(pattern_table)(mod) + mod = tvm.relay.transform.AnnotateTarget(["ilacnn"])(mod) + mod = tvm.relay.transform.PartitionGraph()(mod) + print(mod) + + with tvm.transform.PassContext(opt_level=3): + device = tvm.cpu() + target = "llvm" + exe = relay.vm.compile(mod, target) + vm = runtime.vm.VirtualMachine(exe, device) + + args = [np.random.rand(1, 3, 220, 218).astype("float32"), + np.random.rand(3, 3, 3, 3).astype("float32")] + ret = vm.invoke("main", *args) + + +if __name__ == "__main__": + test_conv2d_unpadded() + test_conv2d_padded() From 8b40ed4170a11976c5d337ee404ed94b6fa6bf23 Mon Sep 17 00:00:00 2001 From: "Steven S. 
Lyubomirsky" Date: Mon, 29 Mar 2021 21:34:13 -0700 Subject: [PATCH 2/3] Just include the model in the repo because it's not that big anyway --- .gitignore | 2 +- tests/python/byo3la/EfficientNet/README.md | 72 +++++++++ .../byo3la/EfficientNet/efficientnet_model.py | 147 ++++++++++++++++++ .../python/byo3la/end_to_end_efficientnet.py | 15 +- 4 files changed, 227 insertions(+), 9 deletions(-) create mode 100644 tests/python/byo3la/EfficientNet/README.md create mode 100644 tests/python/byo3la/EfficientNet/efficientnet_model.py diff --git a/.gitignore b/.gitignore index 7186387dc..071041777 100644 --- a/.gitignore +++ b/.gitignore @@ -234,4 +234,4 @@ conda/pkg .envrc *.nix -tests/python/byo3la/EfficientNet/* \ No newline at end of file +tests/python/byo3la/*.params \ No newline at end of file diff --git a/tests/python/byo3la/EfficientNet/README.md b/tests/python/byo3la/EfficientNet/README.md new file mode 100644 index 000000000..fb6c3ff9a --- /dev/null +++ b/tests/python/byo3la/EfficientNet/README.md @@ -0,0 +1,72 @@ +Note: Taken from https://github.com/mnikitin/EfficientNet.git + +# EfficientNet-Gluon +[EfficientNet](https://arxiv.org/abs/1905.11946) Gluon implementation + +## ImageNet experiments + +### Requirements +Python 3.7 or later with packages: +- `mxnet >= 1.5.0` +- `gluoncv >= 0.6.0` +- `nvidia-dali >= 0.19.0` + +### Usage +#### Prepare ImageNet dataset +1. Download and extract dataset following this tutorial:
+https://gluon-cv.mxnet.io/build/examples_datasets/imagenet.html
+2. Create mxnet-record files following this tutorial:
+https://gluon-cv.mxnet.io/build/examples_datasets/recordio.html#imagerecord-file-for-imagenet
+
+#### Clone this repo
+```
+git clone https://github.com/mnikitin/EfficientNet.git
+cd EfficientNet/train_imagenet
+```
+
+#### Train your model
+Example of training *efficientnet-b0* with the *nvidia-dali* data loader using 4 GPUs:
+```
+IMAGENET_RECORD_ROOT='path/to/imagenet/record/files'
+MODEL='efficientnet-b0'
+python3 train_dali.py --rec-train $IMAGENET_RECORD_ROOT/train --rec-val $IMAGENET_RECORD_ROOT/val --input-size 224 --batch-size 64 --num-gpus 4 --num-epochs 80 --lr 0.1 --lr-decay-epoch 40,60 --save-dir params-$MODEL --logging-file params-$MODEL/log.txt --save-frequency 5 --mode hybrid --model $MODEL
+```
+
+### Results
+Code in this repo was used to train *efficientnet-b0* and *efficientnet-lite0* models.
+Pretrained params are available (18.8 MB in total = 13.7 MB for *extractor* + 5.1 MB for *classifier*).
+
+| model              | err-top1 | err-top5 | pretrained params |
+|--------------------|----------|----------|-------------------|
+| efficientnet-b0    | 0.335842 | 0.128043 | dropbox link      |
+| efficientnet-lite0 | 0.305316 | 0.106322 | dropbox link      |
+
+**Note** that due to limited computational resources, the obtained results are worse than those in the original paper.
+Moreover, *efficientnet-lite0* was trained using more GPUs and a bigger batch size, so in spite of its simpler architecture (ReLU6 instead of Swish), its results are better than those of the *efficientnet-b0* model.
+Anyway, I believe the provided pretrained params can serve as a good initialization for your task.
+
+Here is exactly how *efficientnet-b0* and *efficientnet-lite0* were trained:
+```
+MODEL='efficientnet-b0'
+python3 train_dali.py --rec-train $IMAGENET_RECORD_ROOT/train --rec-val $IMAGENET_RECORD_ROOT/val --input-size 224 --batch-size 56 --num-gpus 4 --num-epochs 50 --lr 0.1 --lr-decay-epoch 20,30,40 --save-dir params-$MODEL --logging-file params-$MODEL/log.txt --save-frequency 5 --mode hybrid --model $MODEL
+```
+```
+MODEL='efficientnet-lite0'
+python3 train_dali.py --rec-train $IMAGENET_RECORD_ROOT/train --rec-val $IMAGENET_RECORD_ROOT/val --input-size 224 --batch-size 72 --num-gpus 6 --num-epochs 60 --lr 0.1 --lr-decay-epoch 20,35,50 --save-dir params-$MODEL --logging-file params-$MODEL/log.txt --save-frequency 5 --mode hybrid --model $MODEL
+```
diff --git a/tests/python/byo3la/EfficientNet/efficientnet_model.py b/tests/python/byo3la/EfficientNet/efficientnet_model.py
new file mode 100644
index 000000000..72272d077
--- /dev/null
+++ b/tests/python/byo3la/EfficientNet/efficientnet_model.py
@@ -0,0 +1,147 @@
+from mxnet.gluon.block import HybridBlock
+from mxnet.gluon import nn
+from math import ceil
+
+
+class ReLU6(nn.HybridBlock):
+    def __init__(self, **kwargs):
+        super(ReLU6, self).__init__(**kwargs)
+
+    def hybrid_forward(self, F, x):
+        return F.clip(x, 0, 6, name="relu6")
+
+
+def _add_conv(out, channels=1, kernel=1, stride=1, pad=0,
+              num_group=1, active=True, lite=False):
+    out.add(nn.Conv2D(channels, kernel, stride, pad, groups=num_group, use_bias=False))
+    out.add(nn.BatchNorm(scale=True, momentum=0.99, epsilon=1e-3))
+    if active:
+        if lite:
+            out.add(ReLU6())
+        else:
+            out.add(nn.Swish())
+
+
+class MBConv(nn.HybridBlock):
+    def __init__(self, in_channels, channels, t, kernel, stride, lite, **kwargs):
+        super(MBConv, self).__init__(**kwargs)
+        self.use_shortcut = stride == 1 and in_channels == channels
+        with self.name_scope():
+            self.out = nn.HybridSequential()
+            _add_conv(self.out, in_channels * t, active=True, lite=lite)
+            _add_conv(self.out, in_channels * t, kernel=kernel, stride=stride,
+                      pad=int((kernel-1)/2), num_group=in_channels * t,
+                      active=True, lite=lite)
+            _add_conv(self.out, channels, active=False, lite=lite)
+
+    def hybrid_forward(self, F, x):
+        out = self.out(x)
+        if self.use_shortcut:
+            out = F.elemwise_add(out, x)
+        return out
+
+
+class EfficientNet(nn.HybridBlock):
+    r"""
+    Parameters
+    ----------
+    alpha : float, default 1.0
+        The depth multiplier for controlling the model size. The actual number of layers on each channel_size level
+        is equal to the original number of layers multiplied by alpha.
+    beta : float, default 1.0
+        The width multiplier for controlling the model size. The actual number of channels
+        is equal to the original channel size multiplied by beta.
+    dropout_rate : float, default 0.0
+        Dropout probability for the final features layer.
+    classes : int, default 1000
+        Number of classes for the output layer.
+ """ + + def __init__(self, alpha=1.0, beta=1.0, lite=False, + dropout_rate=0.0, classes=1000, **kwargs): + super(EfficientNet, self).__init__(**kwargs) + with self.name_scope(): + self.features = nn.HybridSequential(prefix='features_') + with self.features.name_scope(): + # stem conv + channels = 32 if lite else int(32 * beta) + _add_conv(self.features, channels, kernel=3, stride=2, pad=1, + active=True, lite=lite) + + # base model settings + repeats = [1, 2, 2, 3, 3, 4, 1] + channels_num = [16, 24, 40, 80, 112, 192, 320] + kernels_num = [3, 3, 5, 3, 5, 5, 3] + t_num = [1, 6, 6, 6, 6, 6, 6] + strides_first = [1, 2, 2, 1, 2, 2, 1] + + # determine params of MBConv layers + in_channels_group = [] + for rep, ch_num in zip([1] + repeats[:-1], [32] + channels_num[:-1]): + in_channels_group += [int(ch_num * beta)] * int(ceil(alpha * rep)) + channels_group, kernels, ts, strides = [], [], [], [] + for rep, ch, kernel, t, s in zip(repeats, channels_num, kernels_num, t_num, strides_first): + rep = int(ceil(alpha * rep)) + channels_group += [int(ch * beta)] * rep + kernels += [kernel] * rep + ts += [t] * rep + strides += [s] + [1] * (rep - 1) + + # add MBConv layers + for in_c, c, t, k, s in zip(in_channels_group, channels_group, ts, kernels, strides): + self.features.add(MBConv(in_channels=in_c, channels=c, t=t, kernel=k, + stride=s, lite=lite)) + + # head layers + last_channels = int(1280 * beta) if not lite and beta > 1.0 else 1280 + _add_conv(self.features, last_channels, active=True, lite=lite) + self.features.add(nn.GlobalAvgPool2D()) + + # features dropout + self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0.0 else None + + # output layer + self.output = nn.HybridSequential(prefix='output_') + with self.output.name_scope(): + self.output.add( + nn.Conv2D(classes, 1, use_bias=False, prefix='pred_'), + nn.Flatten() + ) + + def hybrid_forward(self, F, x): + x = self.features(x) + if self.dropout: + x = self.dropout(x) + x = self.output(x) + return x + + +def get_efficientnet(model_name, num_classes=1000): + params_dict = { # (width_coefficient, depth_coefficient, input_resolution, dropout_rate) + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5) + } + width_coeff, depth_coeff, input_resolution, dropout_rate = params_dict[model_name] + model = EfficientNet(alpha=depth_coeff, beta=width_coeff, lite=False, + dropout_rate=dropout_rate, classes=num_classes) + return model, input_resolution + + +def get_efficientnet_lite(model_name, num_classes=1000): + params_dict = { # (width_coefficient, depth_coefficient, input_resolution, dropout_rate) + 'efficientnet-lite0': (1.0, 1.0, 224, 0.2), + 'efficientnet-lite1': (1.0, 1.1, 240, 0.2), + 'efficientnet-lite2': (1.1, 1.2, 260, 0.3), + 'efficientnet-lite3': (1.2, 1.4, 280, 0.3), + 'efficientnet-lite4': (1.4, 1.8, 300, 0.3) + } + width_coeff, depth_coeff, input_resolution, dropout_rate = params_dict[model_name] + model = EfficientNet(alpha=depth_coeff, beta=width_coeff, lite=True, + dropout_rate=dropout_rate, classes=num_classes) + return model, input_resolution diff --git a/tests/python/byo3la/end_to_end_efficientnet.py b/tests/python/byo3la/end_to_end_efficientnet.py index 4e8fcfa86..d99613dd7 100644 --- a/tests/python/byo3la/end_to_end_efficientnet.py +++ 
b/tests/python/byo3la/end_to_end_efficientnet.py @@ -12,24 +12,23 @@ from tvm import runtime from tvm.relay.op.contrib import ilacnn +from EfficientNet.efficientnet_model import get_efficientnet + TEST_DIR = os.path.dirname(os.path.abspath(__file__)) ENET_DIR = os.path.join(TEST_DIR, "EfficientNet") PARAMS_FILE = os.path.join(ENET_DIR, "0.3358-imagenet-efficientnet-b0-47-best.params") -def efficientnet_present(): - return os.path.exists(ENET_DIR) and os.path.exists(PARAMS_FILE) +def data_present(): + return os.path.exists(PARAMS_FILE) -def pull_efficientnet(): - subprocess.run(["rm", "-rf", ENET_DIR]) - subprocess.run(["git", "clone", "https://github.com/mnikitin/EfficientNet.git"], cwd=TEST_DIR) +def get_data(): subprocess.run(["wget", "https://www.dropbox.com/s/l2ehu85vmmj3w5w/0.3358-imagenet-efficientnet-b0-47-best.params"], cwd=ENET_DIR) def main(): - if not efficientnet_present(): - pull_efficientnet() - from EfficientNet.efficientnet_model import get_efficientnet + if not data_present(): + get_data() enet, _ = get_efficientnet("efficientnet-b0") enet.load_parameters(PARAMS_FILE) From 2ea98c6704cd7343d202de39bd8f82356591bab6 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 30 Mar 2021 20:53:49 -0700 Subject: [PATCH 3/3] Add utilities for getting counts of operators --- python/tvm/relay/testing/__init__.py | 2 + python/tvm/relay/testing/op_summary.py | 72 +++++++++++++++++++ tests/python/relay/test_op_summary.py | 98 ++++++++++++++++++++++++++ 3 files changed, 172 insertions(+) create mode 100644 python/tvm/relay/testing/op_summary.py create mode 100644 tests/python/relay/test_op_summary.py diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index d5c531579..79ccaa5b2 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -50,6 +50,8 @@ # these are just for testing from .exact_matcher import deduplicate_vars, check_compiler_call +from .op_summary import count_all_ops, count_all_overloads, count_all_ops_in_overloads + def run_opt_pass(expr, opt_pass, import_prelude=False): assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) diff --git a/python/tvm/relay/testing/op_summary.py b/python/tvm/relay/testing/op_summary.py new file mode 100644 index 000000000..93bb5f46a --- /dev/null +++ b/python/tvm/relay/testing/op_summary.py @@ -0,0 +1,72 @@ +""" +Utility functions for counting the number of operators +and BYOC overloads in modules. 
+""" +import tvm +from tvm import relay +from tvm.relay.expr_functor import ExprVisitor + +def is_overload(func): + if func.attrs is None: + return False + return "Compiler" in func.attrs + + +def get_count_expr(counter_class, expr): + counter = counter_class() + counter.visit(expr) + return counter.count + + +def get_count_mod(counter_class, mod): + total_count = 0 + for gv in mod.get_global_vars(): + total_count += get_count_expr(counter_class, mod[gv]) + return total_count + + +class Counter(ExprVisitor): + def __init__(self): + super().__init__() + self.count = 0 + + def eligible(self, expr): + raise NotImplementedError() + + def increment(self, expr): + return 1 + + def visit(self, expr): + if self.eligible(expr): + self.count += self.increment(expr) + super().visit(expr) + + +class OpCounter(Counter): + def eligible(self, expr): + return isinstance(expr, tvm.ir.op.Op) + + +class OverloadCounter(Counter): + def eligible(self, expr): + return isinstance(expr, relay.Function) and is_overload(expr) + + +class OpInOverloadCounter(Counter): + def eligible(self, expr): + return isinstance(expr, relay.Function) and is_overload(expr) + + def increment(self, expr): + return get_count_expr(OpCounter, expr) + + +def count_all_ops(mod): + return get_count_mod(OpCounter, mod) + + +def count_all_overloads(mod): + return get_count_mod(OverloadCounter, mod) + + +def count_all_ops_in_overloads(mod): + return get_count_mod(OpInOverloadCounter, mod) diff --git a/tests/python/relay/test_op_summary.py b/tests/python/relay/test_op_summary.py new file mode 100644 index 000000000..176d0f998 --- /dev/null +++ b/tests/python/relay/test_op_summary.py @@ -0,0 +1,98 @@ +import tvm +from tvm import relay +from tvm.relay.testing import count_all_ops, count_all_overloads, count_all_ops_in_overloads +from tvm.relay.testing import annotate_exact_matches, deduplicate_vars + +def test_count_chain(): + mod = tvm.IRModule() + x = relay.Var("x") + y = relay.Var("y") + z = relay.Var("z") + w = relay.Var("w") + mod["main"] = relay.Function([x, y, z, w], relay.nn.conv2d(x + z, w*y)) + assert count_all_ops(mod) == 3 + + +def test_count_multiple_funcs(): + mod = tvm.IRModule() + x = relay.Var("x") + y = relay.Var("y") + z = relay.Var("z") + w = relay.Var("w") + gv = relay.GlobalVar("f2") + mod["main"] = relay.Function([x, y, z, w], relay.nn.conv2d(x + z, gv(z, w))) + a = relay.Var("a") + b = relay.Var("b") + mod[gv] = relay.Function([a, b], a*b) + assert count_all_ops(mod) == 3 + + +def test_count_single_overload(): + x = relay.Var("x") + notnot = relay.logical_not(relay.logical_not(x)) + + mod = tvm.IRModule() + mod["main"] = annotate_exact_matches( + relay.Function([x], notnot), + deduplicate_vars(notnot), + "MyCompiler", "notnot") + + assert count_all_overloads(mod) == 1 + assert count_all_ops_in_overloads(mod) == 2 + + +def test_count_multiple_overloads(): + x = relay.Var("x") + y = relay.Var("y") + conv = relay.nn.conv2d(x, y) + add = x + y + + mod = tvm.IRModule() + a = relay.Var("a") + b = relay.Var("b") + c = relay.Var("c") + match_conv = annotate_exact_matches( + relay.Function([a, b, c], relay.nn.conv2d(a + b, c)), + conv, + "MyCompiler", "conv" + ) + match_add = annotate_exact_matches( + match_conv, + add, + "MyCompiler", "add") + mod["main"] = match_add + assert count_all_overloads(mod) == 2 + assert count_all_ops_in_overloads(mod) == 2 + + +def test_count_overloads_multiple_funcs(): + x, y, z = relay.Var("x"), relay.Var("y"), relay.Var("z") + linear_layer = relay.nn.bias_add(relay.nn.dense(x, y), z) + conv = 
relay.nn.conv2d(x, y) + + mod = tvm.IRModule() + + a, b, c = relay.Var("a"), relay.Var("b"), relay.Var("c") + lin_func = relay.Function([a, b, c], + relay.nn.bias_add(relay.nn.dense(a, b), c)) + match_lin = annotate_exact_matches(lin_func, linear_layer, "MyCompiler", "linear") + + linear_var = relay.GlobalVar("linear_layer") + mod[linear_var] = match_lin + + d, e, f, g = relay.Var("d"), relay.Var("e"), relay.Var("f"), relay.Var("g") + main_func = relay.Function([d, e, f, g], + relay.nn.conv2d(linear_var(d, e, f), g)) + match_conv = annotate_exact_matches(main_func, conv, "MyCompiler", "Conv") + mod["main"] = match_conv + + assert count_all_overloads(mod) == 2 + assert count_all_ops_in_overloads(mod) == 3 + + +if __name__ == "__main__": + test_count_chain() + test_count_multiple_funcs() + test_count_single_overload() + test_count_multiple_overloads() + test_count_overloads_multiple_funcs()
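+
+
+# A sketch (not executed above) of how these utilities compose with the BYOC
+# flow from the earlier ILACNN patches: after MergeComposite, AnnotateTarget,
+# and PartitionGraph, as in match_conv2d.py, a module whose main is a single
+# conv2d should end up with that one op inside one "Compiler"-annotated
+# overload:
+#
+#   mod = tvm.relay.transform.MergeComposite(ilacnn.pattern_table())(mod)
+#   mod = tvm.relay.transform.AnnotateTarget(["ilacnn"])(mod)
+#   mod = tvm.relay.transform.PartitionGraph()(mod)
+#   assert count_all_overloads(mod) == 1
+#   assert count_all_ops_in_overloads(mod) == 1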