From 0404df02cbb9e39631b8ed12853737b244234a27 Mon Sep 17 00:00:00 2001 From: maniospas Date: Tue, 27 Aug 2024 01:34:16 +0300 Subject: [PATCH] finished training abstraction schemas --- .../classification/LogisticRegression.java | 3 +- JGNN/src/examples/classification/MLP.java | 22 +- .../graphClassification/SortPooling.java | 60 ++-- .../SortPoolingManual.java | 84 +++++ .../graphClassification/TrajectoryData.java | 4 +- .../examples/nodeClassification/APPNP.java | 26 +- JGNN/src/examples/nodeClassification/GAT.java | 22 +- JGNN/src/examples/nodeClassification/GCN.java | 23 +- .../examples/nodeClassification/GCNII.java | 25 +- .../examples/nodeClassification/HetGCN.java | 23 +- .../nodeClassification/MessagePassing.java | 23 +- .../nodeClassification/Scripting.java | 18 +- JGNN/src/examples/tutorial/Learning.java | 3 +- JGNN/src/examples/tutorial/NN.java | 26 +- JGNN/src/examples/tutorial/Quickstart.java | 3 +- .../main/java/mklab/JGNN/adhoc/BatchData.java | 20 ++ .../java/mklab/JGNN/adhoc/ModelTraining.java | 144 +++++++- .../JGNN/adhoc/datasets/package-info.java | 2 +- .../mklab/JGNN/adhoc/train/AGFTraining.java | 97 ++++++ .../JGNN/adhoc/train/NodeClassification.java | 114 ------- .../adhoc/train/SampleClassification.java | 135 ++++++++ .../src/main/java/mklab/JGNN/core/Matrix.java | 6 +- JGNN/src/main/java/mklab/JGNN/core/Slice.java | 2 +- .../src/main/java/mklab/JGNN/core/Tensor.java | 4 +- .../mklab/JGNN/core/matrix/DenseMatrix.java | 5 +- .../JGNN/core/matrix/VectorizedMatrix.java | 2 +- .../main/java/mklab/JGNN/core/util/Range.java | 60 ++-- JGNN/src/main/java/mklab/JGNN/nn/Loss.java | 15 + JGNN/src/main/java/mklab/JGNN/nn/Model.java | 307 +++++++++++------- .../mklab/JGNN/nn/inputs/package-info.java | 2 +- .../JGNN/nn/loss/report/VerboseLoss.java | 54 ++- docs/index.html | 82 ++--- tutorials/Data.md | 105 ------ tutorials/Debugging.md | 149 --------- tutorials/GNN.md | 185 ----------- tutorials/GraphClassification.md | 190 ----------- tutorials/Learning.md | 135 -------- tutorials/Message.md | 85 ----- tutorials/Models.md | 142 -------- tutorials/NN.md | 110 ------- tutorials/Neuralang.md | 77 ----- tutorials/Primitives.md | 145 --------- tutorials/README.md | 17 - tutorials/graphviz.png | Bin 36910 -> 0 bytes 44 files changed, 962 insertions(+), 1794 deletions(-) create mode 100644 JGNN/src/examples/graphClassification/SortPoolingManual.java create mode 100644 JGNN/src/main/java/mklab/JGNN/adhoc/BatchData.java create mode 100644 JGNN/src/main/java/mklab/JGNN/adhoc/train/AGFTraining.java delete mode 100644 JGNN/src/main/java/mklab/JGNN/adhoc/train/NodeClassification.java create mode 100644 JGNN/src/main/java/mklab/JGNN/adhoc/train/SampleClassification.java delete mode 100644 tutorials/Data.md delete mode 100644 tutorials/Debugging.md delete mode 100644 tutorials/GNN.md delete mode 100644 tutorials/GraphClassification.md delete mode 100644 tutorials/Learning.md delete mode 100644 tutorials/Message.md delete mode 100644 tutorials/Models.md delete mode 100644 tutorials/NN.md delete mode 100644 tutorials/Neuralang.md delete mode 100644 tutorials/Primitives.md delete mode 100644 tutorials/README.md delete mode 100644 tutorials/graphviz.png diff --git a/JGNN/src/examples/classification/LogisticRegression.java b/JGNN/src/examples/classification/LogisticRegression.java index 54a0b8e2..83571724 100644 --- a/JGNN/src/examples/classification/LogisticRegression.java +++ b/JGNN/src/examples/classification/LogisticRegression.java @@ -4,6 +4,7 @@ import mklab.JGNN.adhoc.ModelBuilder; import mklab.JGNN.adhoc.ModelTraining; import mklab.JGNN.adhoc.datasets.Citeseer; +import mklab.JGNN.adhoc.train.SampleClassification; import mklab.JGNN.core.Matrix; import mklab.JGNN.nn.Model; import mklab.JGNN.nn.loss.Accuracy; @@ -44,7 +45,7 @@ public static void main(String[] args) { long tic = System.currentTimeMillis(); - Model model = new ModelTraining() + Model model = new SampleClassification() .setOptimizer(new GradientDescent(0.01)) .setEpochs(600) .setNumBatches(10) diff --git a/JGNN/src/examples/classification/MLP.java b/JGNN/src/examples/classification/MLP.java index 093f2fbf..55fd55f2 100644 --- a/JGNN/src/examples/classification/MLP.java +++ b/JGNN/src/examples/classification/MLP.java @@ -4,6 +4,7 @@ import mklab.JGNN.adhoc.ModelBuilder; import mklab.JGNN.adhoc.ModelTraining; import mklab.JGNN.adhoc.datasets.Citeseer; +import mklab.JGNN.adhoc.train.SampleClassification; import mklab.JGNN.core.Matrix; import mklab.JGNN.nn.Model; import mklab.JGNN.core.Slice; @@ -11,6 +12,7 @@ import mklab.JGNN.nn.initializers.XavierNormal; import mklab.JGNN.nn.loss.Accuracy; import mklab.JGNN.nn.loss.BinaryCrossEntropy; +import mklab.JGNN.nn.loss.report.VerboseLoss; import mklab.JGNN.nn.optimizers.Adam; /** @@ -42,20 +44,24 @@ public static void main(String[] args) { Slice nodeIds = dataset.samples().getSlice().shuffle(100); - long tic = System.currentTimeMillis(); - Model model = new ModelTraining() + Slice nodes = dataset.samples().getSlice().shuffle(100); + ModelTraining trainer = new SampleClassification() + .setFeatures(dataset.features()) + .setOutputs(dataset.labels()) + .setTrainingSamples(nodes.range(0, 0.6)) + .setValidationSamples(nodes.range(0.6, 0.8)) .setOptimizer(new Adam(0.01)) .setEpochs(3000) .setPatience(300) .setNumBatches(20) .setParallelizedStochasticGradientDescent(true) .setLoss(new BinaryCrossEntropy()) - .setVerbose(true) - .setValidationLoss(new Accuracy()) - .train(new XavierNormal().apply(modelBuilder.getModel()), - dataset.features(), - dataset.labels(), - nodeIds.range(0, 0.7), nodeIds.range(0.7, 0.8)); + .setValidationLoss(new VerboseLoss(new Accuracy())); + + long tic = System.currentTimeMillis(); + Model model = modelBuilder.getModel() + .init(new XavierNormal()) + .train(trainer); long toc = System.currentTimeMillis(); double acc = 0; diff --git a/JGNN/src/examples/graphClassification/SortPooling.java b/JGNN/src/examples/graphClassification/SortPooling.java index 2dfa8ae5..b9f23067 100644 --- a/JGNN/src/examples/graphClassification/SortPooling.java +++ b/JGNN/src/examples/graphClassification/SortPooling.java @@ -3,14 +3,18 @@ import java.util.Arrays; import mklab.JGNN.adhoc.ModelBuilder; +import mklab.JGNN.adhoc.ModelTraining; import mklab.JGNN.adhoc.parsers.LayeredBuilder; +import mklab.JGNN.adhoc.train.AGFTraining; import mklab.JGNN.core.Matrix; import mklab.JGNN.core.Tensor; import mklab.JGNN.core.ThreadPool; import mklab.JGNN.nn.Loss; import mklab.JGNN.nn.Model; import mklab.JGNN.nn.initializers.XavierNormal; +import mklab.JGNN.nn.loss.Accuracy; import mklab.JGNN.nn.loss.CategoricalCrossEntropy; +import mklab.JGNN.nn.loss.report.VerboseLoss; import mklab.JGNN.nn.optimizers.Adam; import mklab.JGNN.nn.optimizers.BatchOptimizer; @@ -45,40 +49,30 @@ public static void main(String[] args){ TrajectoryData dtrain = new TrajectoryData(8000); TrajectoryData dtest = new TrajectoryData(2000); - Model model = builder.getModel().init(new XavierNormal()); - BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.01)); - Loss loss = new CategoricalCrossEntropy(); - for(int epoch=0; epoch<600; epoch++) { - // gradient update over all graphs - for(int graphId=0; graphId