diff --git a/README.md b/README.md index 7dceb7e..4226d50 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Dataset dataset = new Cora(); Matrix adjacency = dataset.graph().setMainDiagonal(1).setToSymmetricNormalization(); Matrix nodeFeatures = dataset.features(); Matrix nodeLabels = dataset.labels(); -Slice nodes = new Slice(0, nodeLabels.getRows()).shuffle(100); +Slice nodes = dataset.samples().getSlice().shuffle(100); long numClasses = nodeLabels.getCols(); ModelBuilder modelBuilder = new FastBuilder(adjacency, nodeFeatures) @@ -64,22 +64,23 @@ ModelTraining trainer = new ModelTraining() .setEpochs(300) .setPatience(100) .setLoss(new CategoricalCrossEntropy()) + .setVerbose(true) .setValidationLoss(new CategoricalCrossEntropy()); Model model = modelBuilder.getModel() .init(new XavierNormal()) .train(trainer, - nodes.samplesAsFeatures(), - nodeLabels, - nodes.range(0, 0.6), - nodes.range(0.6, 0.8)); + nodes.samplesAsFeatures(), + nodeLabels, + nodes.range(0, 0.6), + nodes.range(0.6, 0.8)); Matrix output = model.predict(nodes.samplesAsFeatures()).get(0).cast(Matrix.class); double acc = 0; for(Long node : nodes.range(0.8, 1)) { - Matrix nodeLabels = dataset.labels().accessRow(node).asRow(); + Matrix trueLabels = dataset.labels().accessRow(node).asRow(); Tensor nodeOutput = output.accessRow(node).asRow(); - acc += nodeOutput.argmax()==nodeLabels.argmax()?1:0; + acc += nodeOutput.argmax()==trueLabels.argmax()?1:0; } System.out.println("Acc\t "+acc/nodes.range(0.8, 1).size()); ```