Commit 0404df0

finished training abstraction schemas

maniospas committed Aug 26, 2024
1 parent e3cdc59

Showing 44 changed files with 962 additions and 1,794 deletions.
JGNN/src/examples/classification/LogisticRegression.java (3 changes: 2 additions & 1 deletion)

@@ -4,6 +4,7 @@
 import mklab.JGNN.adhoc.ModelBuilder;
 import mklab.JGNN.adhoc.ModelTraining;
 import mklab.JGNN.adhoc.datasets.Citeseer;
+import mklab.JGNN.adhoc.train.SampleClassification;
 import mklab.JGNN.core.Matrix;
 import mklab.JGNN.nn.Model;
 import mklab.JGNN.nn.loss.Accuracy;
@@ -44,7 +45,7 @@ public static void main(String[] args) {


 		long tic = System.currentTimeMillis();
-		Model model = new ModelTraining()
+		Model model = new SampleClassification()
 				.setOptimizer(new GradientDescent(0.01))
 				.setEpochs(600)
 				.setNumBatches(10)
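The hunk above is cut off by the diff viewer. For orientation, here is a minimal end-to-end sketch of the new training abstraction, assembled from the SampleClassification usages in the other files of this commit; the dataset, split ratios, and hyperparameter values are illustrative assumptions rather than this file's exact code, and `dataset`/`modelBuilder` are assumed to be set up as in these examples.

	// Sketch only: calls and values borrowed from this commit's other examples.
	Slice nodes = dataset.samples().getSlice().shuffle(100);     // shuffled sample ids
	ModelTraining trainer = new SampleClassification()
			.setFeatures(dataset.features())                     // one row per sample
			.setOutputs(dataset.labels())                        // one-hot label rows
			.setTrainingSamples(nodes.range(0, 0.6))             // first 60% for training
			.setValidationSamples(nodes.range(0.6, 0.8))         // next 20% for validation
			.setOptimizer(new GradientDescent(0.01))
			.setEpochs(600)
			.setNumBatches(10)
			.setLoss(new BinaryCrossEntropy())
			.setValidationLoss(new VerboseLoss(new Accuracy())); // also prints progress
	Model model = modelBuilder.getModel()
			.init(new XavierNormal())                            // initialize parameters
			.train(trainer);                                     // run the configured training

The data, the training configuration, and the model are thus kept in three separate steps, which is the pattern every example below migrates to.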
JGNN/src/examples/classification/MLP.java (22 changes: 14 additions & 8 deletions)

@@ -4,13 +4,15 @@
 import mklab.JGNN.adhoc.ModelBuilder;
 import mklab.JGNN.adhoc.ModelTraining;
 import mklab.JGNN.adhoc.datasets.Citeseer;
+import mklab.JGNN.adhoc.train.SampleClassification;
 import mklab.JGNN.core.Matrix;
 import mklab.JGNN.nn.Model;
 import mklab.JGNN.core.Slice;
 import mklab.JGNN.core.Tensor;
 import mklab.JGNN.nn.initializers.XavierNormal;
 import mklab.JGNN.nn.loss.Accuracy;
 import mklab.JGNN.nn.loss.BinaryCrossEntropy;
+import mklab.JGNN.nn.loss.report.VerboseLoss;
 import mklab.JGNN.nn.optimizers.Adam;
 
 /**
@@ -42,20 +44,24 @@ public static void main(String[] args) {

 		Slice nodeIds = dataset.samples().getSlice().shuffle(100);
 
-		long tic = System.currentTimeMillis();
-		Model model = new ModelTraining()
+		Slice nodes = dataset.samples().getSlice().shuffle(100);
+		ModelTraining trainer = new SampleClassification()
+				.setFeatures(dataset.features())
+				.setOutputs(dataset.labels())
+				.setTrainingSamples(nodes.range(0, 0.6))
+				.setValidationSamples(nodes.range(0.6, 0.8))
 				.setOptimizer(new Adam(0.01))
 				.setEpochs(3000)
 				.setPatience(300)
 				.setNumBatches(20)
 				.setParallelizedStochasticGradientDescent(true)
 				.setLoss(new BinaryCrossEntropy())
-				.setVerbose(true)
-				.setValidationLoss(new Accuracy())
-				.train(new XavierNormal().apply(modelBuilder.getModel()),
-						dataset.features(),
-						dataset.labels(),
-						nodeIds.range(0, 0.7), nodeIds.range(0.7, 0.8));
+				.setValidationLoss(new VerboseLoss(new Accuracy()));
+
+		long tic = System.currentTimeMillis();
+		Model model = modelBuilder.getModel()
+				.init(new XavierNormal())
+				.train(trainer);
 		long toc = System.currentTimeMillis();
 
 		double acc = 0;
JGNN/src/examples/graphClassification/SortPooling.java (60 changes: 27 additions & 33 deletions)

@@ -3,14 +3,18 @@
 import java.util.Arrays;
 
 import mklab.JGNN.adhoc.ModelBuilder;
+import mklab.JGNN.adhoc.ModelTraining;
 import mklab.JGNN.adhoc.parsers.LayeredBuilder;
+import mklab.JGNN.adhoc.train.AGFTraining;
 import mklab.JGNN.core.Matrix;
 import mklab.JGNN.core.Tensor;
 import mklab.JGNN.core.ThreadPool;
 import mklab.JGNN.nn.Loss;
 import mklab.JGNN.nn.Model;
 import mklab.JGNN.nn.initializers.XavierNormal;
+import mklab.JGNN.nn.loss.Accuracy;
 import mklab.JGNN.nn.loss.CategoricalCrossEntropy;
+import mklab.JGNN.nn.loss.report.VerboseLoss;
 import mklab.JGNN.nn.optimizers.Adam;
 import mklab.JGNN.nn.optimizers.BatchOptimizer;
 
@@ -45,40 +49,30 @@ public static void main(String[] args){
 		TrajectoryData dtrain = new TrajectoryData(8000);
 		TrajectoryData dtest = new TrajectoryData(2000);
 
-		Model model = builder.getModel().init(new XavierNormal());
-		BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.01));
-		Loss loss = new CategoricalCrossEntropy();
-		for(int epoch=0; epoch<600; epoch++) {
-			// gradient update over all graphs
-			for(int graphId=0; graphId<dtrain.graphs.size(); graphId++) {
-				int graphIdentifier = graphId;
-				// each gradient calculation into a new thread pool task
-				ThreadPool.getInstance().submit(new Runnable() {
-					@Override
-					public void run() {
-						//System.out.println(dtrain.graphs.get(graphIdentifier).sum());
-						Matrix adjacency = dtrain.graphs.get(graphIdentifier);
-						Matrix features = dtrain.features.get(graphIdentifier);
-						Tensor graphLabel = dtrain.labels.get(graphIdentifier).asRow();
-
-						model.train(loss, optimizer,
-								Arrays.asList(features, adjacency),
-								Arrays.asList(graphLabel));
-					}
-				});
-				ThreadPool.getInstance().waitForConclusion(); // wait for all gradients to compute
-			}
-			optimizer.updateAll(); // apply gradients on model parameters
+		ModelTraining trainer = new AGFTraining()
+				.setGraphs(dtrain.graphs)
+				.setNodeFeatures(dtrain.features)
+				.setGraphLabels(dtrain.labels)
+				.setValidationSplit(0.2)
+				.setEpochs(300)
+				.setOptimizer(new Adam(0.001))
+				.setLoss(new CategoricalCrossEntropy())
+				//.setNumBatches(10)
+				//.setParallelizedStochasticGradientDescent(true)
+				.setValidationLoss(new VerboseLoss(new CategoricalCrossEntropy(), new Accuracy()));
 
+		Model model = builder.getModel()
+				.init(new XavierNormal())
+				.train(trainer);
 
-			double acc = 0.0;
-			for(int graphId=0; graphId<dtest.graphs.size(); graphId++) {
-				Matrix adjacency = dtest.graphs.get(graphId);
-				Matrix features = dtest.features.get(graphId);
-				Tensor graphLabel = dtest.labels.get(graphId);
-				if(model.predict(Arrays.asList(features, adjacency)).get(0).argmax()==graphLabel.argmax())
-					acc += 1;
-			}
-			System.out.println("iter = " + epoch + " " + acc/dtest.graphs.size());
+		double acc = 0.0;
+		for(int graphId=0; graphId<dtest.graphs.size(); graphId++) {
+			Matrix adjacency = dtest.graphs.get(graphId);
+			Matrix features = dtest.features.get(graphId);
+			Tensor graphLabel = dtest.labels.get(graphId);
+			if(model.predict(Arrays.asList(features, adjacency)).get(0).argmax()==graphLabel.argmax())
+				acc += 1;
+		}
+		System.out.println("Test accuracy " + acc/dtest.graphs.size());
 	}
 }
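For graph-level tasks, the abstraction goes through AGFTraining, which is fed index-aligned lists of adjacency matrices, node-feature matrices, and graph labels. A minimal sketch of that pattern follows, under the assumption (suggested by TrajectoryData's fields) that any aligned List&lt;Matrix&gt;/List&lt;Tensor&gt; collections can stand in for the example dataset:

	// Sketch only: the three lists must be index-aligned, one entry per graph.
	List<Matrix> graphs = dtrain.graphs;      // adjacency matrix per graph
	List<Matrix> features = dtrain.features;  // node-feature matrix per graph
	List<Tensor> labels = dtrain.labels;      // one-hot row label per graph
	ModelTraining trainer = new AGFTraining()
			.setGraphs(graphs)
			.setNodeFeatures(features)
			.setGraphLabels(labels)
			.setValidationSplit(0.2)          // hold out 20% of graphs
			.setEpochs(300)
			.setOptimizer(new Adam(0.001))
			.setLoss(new CategoricalCrossEntropy());
	Model model = builder.getModel().init(new XavierNormal()).train(trainer);

The commented-out setNumBatches and setParallelizedStochasticGradientDescent lines in the diff above suggest that the same batching and parallelism controls as SampleClassification apply here as well.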
JGNN/src/examples/graphClassification/SortPoolingManual.java (84 changes: 84 additions & 0 deletions)

@@ -0,0 +1,84 @@
+package graphClassification;
+
+import java.util.Arrays;
+
+import mklab.JGNN.adhoc.ModelBuilder;
+import mklab.JGNN.adhoc.parsers.LayeredBuilder;
+import mklab.JGNN.core.Matrix;
+import mklab.JGNN.core.Tensor;
+import mklab.JGNN.core.ThreadPool;
+import mklab.JGNN.nn.Loss;
+import mklab.JGNN.nn.Model;
+import mklab.JGNN.nn.initializers.XavierNormal;
+import mklab.JGNN.nn.loss.CategoricalCrossEntropy;
+import mklab.JGNN.nn.optimizers.Adam;
+import mklab.JGNN.nn.optimizers.BatchOptimizer;
+
+/**
+ *
+ * @author github.com/gavalian
+ * @author Emmanouil Krasanakis
+ */
+public class SortPoolingManual {
+
+	public static void main(String[] args){
+		long reduced = 5; // input graphs need to have at least that many nodes, lower values decrease accuracy
+		long hidden = 8; // since this library does not use GPU parallelization, many latent dims reduce speed
+
+		ModelBuilder builder = new LayeredBuilder()
+				.var("A")
+				.config("features", 1)
+				.config("classes", 2)
+				.config("reduced", reduced)
+				.config("hidden", hidden)
+				.config("reg", 0.005)
+				.layer("h{l+1}=relu(A@(h{l}@matrix(features, hidden, reg))+vector(hidden))")
+				.layer("h{l+1}=relu(A@(h{l}@matrix(hidden, hidden, reg))+vector(hidden))")
+				.concat(2) // concatenates the outputs of the last 2 layers
+				.config("hiddenReduced", hidden*2*reduced) // 2* due to concatenation
+				.operation("z{l}=sort(h{l}, reduced)") // currently, the parser fails to understand full expressions within the next step's gather, so we need this intermediate variable
+				.layer("h{l+1}=reshape(h{l}[z{l}], 1, hiddenReduced)")
+				.layer("h{l+1}=h{l}@matrix(hiddenReduced, classes)")
+				.layer("h{l+1}=softmax(h{l}, dim: 'row')")
+				//.layer("h{l+1}=softmax(sum(h{l}@matrix(hiddenReduced, classes), row))") // mean pooling that could replace the sort pooling above
+				.out("h{l}");
+		TrajectoryData dtrain = new TrajectoryData(8000);
+		TrajectoryData dtest = new TrajectoryData(2000);
+
+		Model model = builder.getModel().init(new XavierNormal());
+		BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.01));
+		Loss loss = new CategoricalCrossEntropy();
+		for(int epoch=0; epoch<600; epoch++) {
+			// gradient update over all graphs
+			for(int graphId=0; graphId<dtrain.graphs.size(); graphId++) {
+				int graphIdentifier = graphId;
+				// each gradient calculation into a new thread pool task
+				ThreadPool.getInstance().submit(new Runnable() {
+					@Override
+					public void run() {
+						//System.out.println(dtrain.graphs.get(graphIdentifier).sum());
+						Matrix adjacency = dtrain.graphs.get(graphIdentifier);
+						Matrix features = dtrain.features.get(graphIdentifier);
+						Tensor graphLabel = dtrain.labels.get(graphIdentifier).asRow();
+
+						model.train(loss, optimizer,
+								Arrays.asList(features, adjacency),
+								Arrays.asList(graphLabel));
+					}
+				});
+				ThreadPool.getInstance().waitForConclusion(); // wait for all gradients to compute
+			}
+			optimizer.updateAll(); // apply gradients on model parameters
+
+			double acc = 0.0;
+			for(int graphId=0; graphId<dtest.graphs.size(); graphId++) {
+				Matrix adjacency = dtest.graphs.get(graphId);
+				Matrix features = dtest.features.get(graphId);
+				Tensor graphLabel = dtest.labels.get(graphId);
+				if(model.predict(Arrays.asList(features, adjacency)).get(0).argmax()==graphLabel.argmax())
+					acc += 1;
+			}
+			System.out.println("iter = " + epoch + " " + acc/dtest.graphs.size());
+		}
+	}
+}
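Worked dimension check for the builder configuration above, using the declared defaults: each of the two stacked layers emits hidden = 8 features per node, so concat(2) produces 16 features per node; sort keeps the reduced = 5 top-ranked nodes, and reshape flattens their features into a single row of width hiddenReduced = hidden*2*reduced = 8*2*5 = 80; the final dense layer then maps this 1x80 row to the 2 class scores that softmax normalizes.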
JGNN/src/examples/graphClassification/TrajectoryData.java (4 changes: 2 additions & 2 deletions)

@@ -73,8 +73,8 @@ protected void createGraphs(int numGraphs){
 			this.features.add(ft);
 			this.features.add(ff);
 
-			this.labels.add(new DenseTensor(2).put(0, 1.0));
-			this.labels.add(new DenseTensor(2).put(1, 1.0));
+			this.labels.add(new DenseTensor(2).put(0, 1.0).asRow());
+			this.labels.add(new DenseTensor(2).put(1, 1.0).asRow());
 		}
 	}
 }
JGNN/src/examples/nodeClassification/APPNP.java (26 changes: 15 additions & 11 deletions)

@@ -6,12 +6,15 @@
 import mklab.JGNN.adhoc.ModelTraining;
 import mklab.JGNN.adhoc.datasets.Cora;
 import mklab.JGNN.adhoc.parsers.FastBuilder;
+import mklab.JGNN.adhoc.train.SampleClassification;
 import mklab.JGNN.core.Matrix;
 import mklab.JGNN.nn.Model;
 import mklab.JGNN.core.Slice;
 import mklab.JGNN.core.Tensor;
 import mklab.JGNN.nn.initializers.XavierNormal;
+import mklab.JGNN.nn.loss.Accuracy;
 import mklab.JGNN.nn.loss.CategoricalCrossEntropy;
+import mklab.JGNN.nn.loss.report.VerboseLoss;
 import mklab.JGNN.nn.optimizers.Adam;
 
 /**
@@ -25,7 +28,7 @@ public static void main(String[] args) throws Exception {
 		dataset.graph().setMainDiagonal(1).setToSymmetricNormalization();
 		dataset.graph().setDimensionName("nodes", "nodes");
 		dataset.features().setDimensionName("nodes", "features");
-		dataset.labels().setDimensionName("nodes", "labels");
+		dataset.labels().setDimensionName("nodes", "classes");
 
 		long numClasses = dataset.labels().getCols();
 		ModelBuilder modelBuilder = new FastBuilder(dataset.graph(), dataset.features())
@@ -38,24 +41,25 @@
.constant("a", 0.9)
.layerRepeat("h{l+1} = a*(dropout(A, 0.5)@h{l})+(1-a)*h{0}", 10)
.classify();

ModelTraining trainer = new ModelTraining()

Slice nodes = dataset.samples().getSlice().shuffle(100);
ModelTraining trainer = new SampleClassification()
// set data
.setFeatures(nodes.samplesAsFeatures())
.setOutputs(dataset.labels())
.setTrainingSamples(nodes.range(0, 0.6))
.setValidationSamples(nodes.range(0.6, 0.8))
// configure how training is conducted
.setOptimizer(new Adam(0.01))
.setEpochs(300)
.setPatience(100)
.setVerbose(true)
.setLoss(new CategoricalCrossEntropy())
.setValidationLoss(new CategoricalCrossEntropy());
.setValidationLoss(new VerboseLoss(new CategoricalCrossEntropy(), new Accuracy()).setInterval(10));

long tic = System.currentTimeMillis();
Slice nodes = dataset.samples().getSlice().shuffle(100);
Model model = modelBuilder.getModel()
.init(new XavierNormal())
.train(trainer,
nodes.samplesAsFeatures(),
dataset.labels(),
nodes.range(0, 0.6),
nodes.range(0.6, 0.8));
.train(trainer);

System.out.println("Training time "+(System.currentTimeMillis()-tic)/1000.);
Matrix output = model.predict(nodes.samplesAsFeatures()).get(0).cast(Matrix.class);
Expand Down
JGNN/src/examples/nodeClassification/GAT.java (22 changes: 12 additions & 10 deletions)

@@ -5,14 +5,15 @@
 import mklab.JGNN.adhoc.ModelTraining;
 import mklab.JGNN.adhoc.datasets.Citeseer;
 import mklab.JGNN.adhoc.parsers.FastBuilder;
-import mklab.JGNN.adhoc.train.NodeClassification;
+import mklab.JGNN.adhoc.train.SampleClassification;
 import mklab.JGNN.core.Matrix;
 import mklab.JGNN.nn.Model;
 import mklab.JGNN.core.Slice;
 import mklab.JGNN.core.Tensor;
 import mklab.JGNN.nn.initializers.XavierNormal;
 import mklab.JGNN.nn.loss.Accuracy;
 import mklab.JGNN.nn.loss.CategoricalCrossEntropy;
+import mklab.JGNN.nn.loss.report.VerboseLoss;
 import mklab.JGNN.nn.optimizers.Adam;
 
 public class GAT {
@@ -35,23 +36,24 @@ public static void main(String[] args) throws Exception {
 				.classify()
 				.assertBackwardValidity();
 
-		ModelTraining trainer = new NodeClassification()
+		Slice nodes = dataset.samples().getSlice().shuffle(100);
+		ModelTraining trainer = new SampleClassification()
+				// set data
+				.setFeatures(nodes.samplesAsFeatures())
+				.setOutputs(dataset.labels())
+				.setTrainingSamples(nodes.range(0, 0.6))
+				.setValidationSamples(nodes.range(0.6, 0.8))
+				// configure how training is conducted
 				.setOptimizer(new Adam(0.01))
 				.setEpochs(300)
 				.setPatience(100)
-				.setVerbose(true)
 				.setLoss(new CategoricalCrossEntropy())
-				.setValidationLoss(new Accuracy());
+				.setValidationLoss(new VerboseLoss(new CategoricalCrossEntropy(), new Accuracy()));
 
 		long tic = System.currentTimeMillis();
-		Slice nodes = dataset.samples().getSlice().shuffle(100);
 		Model model = modelBuilder.getModel()
 				.init(new XavierNormal())
-				.train(trainer,
-						Tensor.fromRange(nodes.size()).asColumn(),
-						dataset.labels(),
-						nodes.range(0, 0.6),
-						nodes.range(0.6, 0.8));
+				.train(trainer);
 
 		System.out.println("Training time "+(System.currentTimeMillis()-tic)/1000.);
 		Matrix output = model.predict(Tensor.fromRange(0, nodes.size()).asColumn()).get(0).cast(Matrix.class);
(Diffs for the remaining changed files are not loaded in this view.)