Skip to content

Commit

Permalink
Working quickstart
Browse files Browse the repository at this point in the history
  • Loading branch information
maniospas authored Jul 23, 2024
1 parent 68dcf2c commit 7370a08
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Dataset dataset = new Cora();
Matrix adjacency = dataset.graph().setMainDiagonal(1).setToSymmetricNormalization();
Matrix nodeFeatures = dataset.features();
Matrix nodeLabels = dataset.labels();
Slice nodes = new Slice(0, nodeLabels.getRows()).shuffle(100);
Slice nodes = dataset.samples().getSlice().shuffle(100);
long numClasses = nodeLabels.getCols();

ModelBuilder modelBuilder = new FastBuilder(adjacency, nodeFeatures)
Expand All @@ -64,22 +64,23 @@ ModelTraining trainer = new ModelTraining()
.setEpochs(300)
.setPatience(100)
.setLoss(new CategoricalCrossEntropy())
.setVerbose(true)
.setValidationLoss(new CategoricalCrossEntropy());

Model model = modelBuilder.getModel()
.init(new XavierNormal())
.train(trainer,
nodes.samplesAsFeatures(),
nodeLabels,
nodes.range(0, 0.6),
nodes.range(0.6, 0.8));
nodes.samplesAsFeatures(),
nodeLabels,
nodes.range(0, 0.6),
nodes.range(0.6, 0.8));

Matrix output = model.predict(nodes.samplesAsFeatures()).get(0).cast(Matrix.class);
double acc = 0;
for(Long node : nodes.range(0.8, 1)) {
Matrix nodeLabels = dataset.labels().accessRow(node).asRow();
Matrix trueLabels = dataset.labels().accessRow(node).asRow();
Tensor nodeOutput = output.accessRow(node).asRow();
acc += nodeOutput.argmax()==nodeLabels.argmax()?1:0;
acc += nodeOutput.argmax()==trueLabels.argmax()?1:0;
}
System.out.println("Acc\t "+acc/nodes.range(0.8, 1).size());
```
Expand Down

0 comments on commit 7370a08

Please sign in to comment.