-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ Subgraph ] Example application for subgraph usecase
- This commit contains an example application to build a large neuralnet with a subgraph. - This application adds subgraph with `is_shared_subgraph=true` option. - The summarization result shows the layers are respectively shared. Signed-off-by: Eunju Yang <[email protected]>
- Loading branch information
Showing
5 changed files
with
23 additions
and
280 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
* | ||
* @file main.cpp | ||
* @date 27 Dec 2024 | ||
* @brief Test Application for shared_from | ||
* @brief Test Application for subgraph weight sharing | ||
* @see https://github.com/nnstreamer/nntrainer | ||
* @author Eunju Yang <[email protected]> | ||
* @bug No known bugs except for NYI items | ||
|
@@ -27,44 +27,24 @@ using LayerHandle = std::shared_ptr<ml::train::Layer>; | |
using ModelHandle = std::unique_ptr<ml::train::Model>; | ||
using UserDataType = std::unique_ptr<nntrainer::util::DataLoader>; | ||
|
||
/** | ||
* @brief tain data callback | ||
*/ | ||
int trainData_cb(float **input, float **label, bool *last, void *user_data) { | ||
auto data = reinterpret_cast<nntrainer::util::DataLoader *>(user_data); | ||
|
||
data->next(input, label, last); | ||
return 0; | ||
} | ||
|
||
/** | ||
* @brief Create subgraph | ||
* @return vector of layers that contain subgraph | ||
*/ | ||
std::vector<LayerHandle> createSubGraph(const std::string &scope, | ||
int subgraph_idx) { | ||
std::vector<LayerHandle> createSubGraph(const std::string &scope) { | ||
|
||
using ml::train::createLayer; | ||
|
||
std::vector<LayerHandle> layers; | ||
|
||
layers.push_back(createLayer( | ||
"fully_connected", | ||
{withKey("name", scope + "/fc_in" + std::to_string(subgraph_idx)), | ||
withKey("unit", 320), | ||
withKey("input_layers", "input/" + std::to_string(subgraph_idx)), | ||
withKey("shared_from", scope + "/fc_in0")})); | ||
layers.push_back(createLayer( | ||
"fully_connected", | ||
{ | ||
withKey("name", scope + "/fc_out" + std::to_string(subgraph_idx)), | ||
withKey("unit", 320), | ||
withKey("input_layers", scope + "/fc_in" + std::to_string(subgraph_idx)), | ||
withKey("shared_from", scope + "/fc_out0"), | ||
})); | ||
layers.push_back(createLayer( | ||
"identity", | ||
{withKey("name", "input/" + std::to_string(subgraph_idx + 1))})); | ||
layers.push_back(createLayer("fully_connected", { | ||
withKey("name", "fc_in"), | ||
withKey("unit", 320), | ||
})); | ||
layers.push_back(createLayer("fully_connected", { | ||
withKey("name", "fc_out"), | ||
withKey("unit", 320), | ||
})); | ||
|
||
return layers; | ||
} | ||
|
@@ -79,12 +59,16 @@ int main(int argc, char *argv[]) { | |
|
||
/** add input layer */ | ||
model->addLayer( | ||
ml::train::createLayer("input", {"name=input/0", "input_shape=1:1:320"})); | ||
ml::train::createLayer("input", {"name=input", "input_shape=1:1:320"})); | ||
|
||
/** create a subgraph structure */ | ||
auto subgraph = createSubGraph("subgraph"); | ||
|
||
/** add subgraphs with shared_from */ | ||
for (auto idx_sg = 0; idx_sg < n_sg; ++idx_sg) { | ||
for (auto &layer : createSubGraph(std::string("subgraph"), idx_sg)) | ||
model->addLayer(layer); | ||
for (unsigned int idx_sg = 0; idx_sg < n_sg; ++idx_sg) { | ||
model->addWithReferenceLayers( | ||
subgraph, "subgraph", {}, {"fc_in"}, {"fc_out"}, | ||
ml::train::ReferenceLayersType::SUBGRAPH, | ||
{withKey("subgraph_idx", idx_sg), withKey("is_shared_subgraph", "true")}); | ||
} | ||
|
||
auto optimizer = ml::train::createOptimizer("sgd", {"learning_rate=0.001"}); | ||
|
@@ -102,4 +86,4 @@ int main(int argc, char *argv[]) { | |
|
||
/** check weight sharing from summary */ | ||
model->summarize(std::cout, ML_TRAIN_SUMMARY_TENSOR); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,4 +32,4 @@ static std::string withKey(const std::string &key, | |
ss << *iter; | ||
|
||
return ss.str(); | ||
} | ||
} |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters