
Commit

local overwrite in neurlang
maniospas committed May 27, 2024
1 parent d63f0e5 commit a135e17
Showing 8 changed files with 126 additions and 188 deletions.
49 changes: 0 additions & 49 deletions JGNN/neural_graph.jggn

This file was deleted.

108 changes: 0 additions & 108 deletions JGNN/src/examples/graphClassification/GCNlang.java

This file was deleted.

13 changes: 1 addition & 12 deletions JGNN/src/examples/nodeClassification/GCN.java
@@ -30,7 +30,7 @@ public static void main(String[] args) throws Exception {
  dataset.graph().setMainDiagonal(1).setToSymmetricNormalization();
 
  long numClasses = dataset.labels().getCols();
- /*ModelBuilder modelBuilder = new FastBuilder(dataset.graph(), dataset.features())
+ ModelBuilder modelBuilder = new FastBuilder(dataset.graph(), dataset.features())
  .config("reg", 0.005)
  .config("classes", numClasses)
  .config("hidden", numClasses)
@@ -39,17 +39,6 @@ public static void main(String[] args) throws Exception {
  .layer("h{l+1}=gcnlayer(A, h{l})")
  .classify()
  .autosize(new EmptyTensor(dataset.samples().getSlice().size()));
- */
- ModelBuilder modelBuilder = new TextBuilder()
- .parse(String.join("\n", Files.readAllLines(Paths.get("../architectures.nn"))))
- .constant("A", dataset.graph())
- .constant("h0", dataset.features())
- .var("nodes")
- .config("classes", numClasses)
- .config("hidden", numClasses)
- .operation("h=gcn(A,h0); out=softmax(h[nodes], row)")
- .out("out")
- .autosize(new EmptyTensor(dataset.samples().getSlice().size()));
 
  ModelTraining trainer = new ModelTraining()
  .setOptimizer(new Adam(0.01))
69 changes: 69 additions & 0 deletions JGNN/src/examples/nodeClassification/Scripting.java
@@ -0,0 +1,69 @@
package nodeClassification;

import java.nio.file.Files;
import java.nio.file.Paths;

import mklab.JGNN.adhoc.Dataset;
import mklab.JGNN.adhoc.ModelBuilder;
import mklab.JGNN.adhoc.datasets.Cora;
import mklab.JGNN.adhoc.parsers.TextBuilder;
import mklab.JGNN.core.Matrix;
import mklab.JGNN.nn.Model;
import mklab.JGNN.nn.ModelTraining;
import mklab.JGNN.core.Slice;
import mklab.JGNN.core.Tensor;
import mklab.JGNN.core.empy.EmptyTensor;
import mklab.JGNN.nn.initializers.XavierNormal;
import mklab.JGNN.nn.loss.CategoricalCrossEntropy;
import mklab.JGNN.nn.optimizers.Adam;

/**
* Demonstrates classification with an architecture defined through the scripting engine.
*
* @author Emmanouil Krasanakis
*/
public class Scripting {
public static void main(String[] args) throws Exception {
Dataset dataset = new Cora();
dataset.graph().setMainDiagonal(1).setToSymmetricNormalization();
long numClasses = dataset.labels().getCols();

ModelBuilder modelBuilder = new TextBuilder()
.parse(String.join("\n", Files.readAllLines(Paths.get("../architectures.nn"))))
.constant("A", dataset.graph())
.constant("h", dataset.features())
.var("nodes")
.config("classes", numClasses)
.config("hidden", numClasses)
.out("classify(nodes, gcn(A,h))");
System.out.println(modelBuilder.getExecutionGraphDot());
modelBuilder
.autosize(new EmptyTensor(dataset.samples().getSlice().size()));

ModelTraining trainer = new ModelTraining()
.setOptimizer(new Adam(modelBuilder.getConfigOrDefault("lr", 0.01)))
.setEpochs(modelBuilder.getConfigOrDefault("epochs", 1000))
.setPatience(modelBuilder.getConfigOrDefault("patience", 100))
.setVerbose(true)
.setLoss(new CategoricalCrossEntropy())
.setValidationLoss(new CategoricalCrossEntropy());

long tic = System.currentTimeMillis();
Slice nodes = dataset.samples().getSlice().shuffle(100);
Model model = modelBuilder.getModel()
.init(new XavierNormal())
.train(trainer,
Tensor.fromRange(nodes.size()).asColumn(),
dataset.labels(), nodes.range(0, 0.6), nodes.range(0.6, 0.8));

System.out.println("Training time "+(System.currentTimeMillis()-tic)/1000.);
Matrix output = model.predict(Tensor.fromRange(0, nodes.size()).asColumn()).get(0).cast(Matrix.class);
double acc = 0;
for(Long node : nodes.range(0.8, 1)) {
Matrix nodeLabels = dataset.labels().accessRow(node).asRow();
Tensor nodeOutput = output.accessRow(node).asRow();
acc += nodeOutput.argmax()==nodeLabels.argmax()?1:0;
}
System.out.println("Acc\t "+acc/nodes.range(0.8, 1).size());
}
}
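
The new Scripting example reads its architecture from ../architectures.nn, a Neuralang script that is not included in the diffs shown above. As an illustration only, the snippet below sketches what such a script could declare: the gcn and classify functions referenced by out("classify(nodes, gcn(A,h))"), with signature defaults for the lr, epochs, and patience configs that the Java code later reads through getConfigOrDefault. The function bodies and exact syntax are assumptions, not the repository's actual file; only the function names, the configs, and the surrounding builder calls are taken from the code above. The builder chain is the same as in Scripting.java, with the file read replaced by an inline string.

 // Sketch under assumptions: the real ../architectures.nn is not part of this commit's
 // diffs, and the Neuralang function bodies below are illustrative guesses. Only the
 // names gcn and classify and the configs (hidden, classes, lr, epochs, patience)
 // come from the Java code in Scripting.java.
 String assumedScript = String.join("\n",
         "fn gcnlayer(A, h, hidden: 64, reg: 0.005) {",
         "    return A@h@matrix(?, hidden, reg) + vector(hidden);",
         "}",
         "fn gcn(A, h, classes: 2) {",
         "    h = gcnlayer(A, h);",
         "    h = dropout(relu(h), 0.5);",
         "    return gcnlayer(A, h, hidden: classes);",
         "}",
         "fn classify(nodes, h, epochs: 1000, patience: 100, lr: 0.01) {",
         "    return softmax(h[nodes], row);",
         "}");
 ModelBuilder modelBuilder = new TextBuilder()
         .parse(assumedScript)            // same TextBuilder entry point used by Scripting.java
         .constant("A", dataset.graph())  // remaining calls copied verbatim from Scripting.java above
         .constant("h", dataset.features())
         .var("nodes")
         .config("classes", numClasses)
         .config("hidden", numClasses)
         .out("classify(nodes, gcn(A,h))");

In this sketch, the ? dimensions in matrix(?, hidden, reg) would be the ones resolved by the later .autosize(new EmptyTensor(...)) call, and the hidden and classes defaults would be overridden by the .config(...) calls on the Java side.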
(Diffs for the remaining 4 changed files were not loaded.)
