diff --git a/JGNN/src/examples/nodeClassification/GCN.java b/JGNN/src/examples/nodeClassification/GCN.java index b829a36..e19eabe 100644 --- a/JGNN/src/examples/nodeClassification/GCN.java +++ b/JGNN/src/examples/nodeClassification/GCN.java @@ -7,7 +7,6 @@ import mklab.JGNN.adhoc.ModelBuilder; import mklab.JGNN.adhoc.datasets.Cora; import mklab.JGNN.adhoc.parsers.FastBuilder; -import mklab.JGNN.adhoc.parsers.TextBuilder; import mklab.JGNN.core.Matrix; import mklab.JGNN.nn.Model; import mklab.JGNN.nn.ModelTraining; diff --git a/JGNN/src/examples/nodeClassification/Scripting.java b/JGNN/src/examples/nodeClassification/Scripting.java index cf15e5b..ae5b850 100644 --- a/JGNN/src/examples/nodeClassification/Scripting.java +++ b/JGNN/src/examples/nodeClassification/Scripting.java @@ -1,12 +1,11 @@ package nodeClassification; -import java.nio.file.Files; import java.nio.file.Paths; import mklab.JGNN.adhoc.Dataset; import mklab.JGNN.adhoc.ModelBuilder; import mklab.JGNN.adhoc.datasets.Cora; -import mklab.JGNN.adhoc.parsers.TextBuilder; +import mklab.JGNN.adhoc.parsers.Neuralang; import mklab.JGNN.core.Matrix; import mklab.JGNN.nn.Model; import mklab.JGNN.nn.ModelTraining; @@ -15,7 +14,6 @@ import mklab.JGNN.core.empy.EmptyTensor; import mklab.JGNN.nn.initializers.XavierNormal; import mklab.JGNN.nn.loss.CategoricalCrossEntropy; -import mklab.JGNN.nn.optimizers.Adam; /** * Demonstrates classification with an architecture defined through the scripting engine. @@ -26,24 +24,19 @@ public class Scripting { public static void main(String[] args) throws Exception { Dataset dataset = new Cora(); dataset.graph().setMainDiagonal(1).setToSymmetricNormalization(); - long numClasses = dataset.labels().getCols(); - ModelBuilder modelBuilder = new TextBuilder() - .parse(String.join("\n", Files.readAllLines(Paths.get("../architectures.nn")))) + ModelBuilder modelBuilder = new Neuralang() + .parse(Paths.get("../architectures.nn")) .constant("A", dataset.graph()) .constant("h", dataset.features()) .var("nodes") - .config("classes", numClasses) - .config("hidden", numClasses) - .out("classify(nodes, gcn(A,h))"); - System.out.println(modelBuilder.getExecutionGraphDot()); - modelBuilder + .config("classes", dataset.labels().getCols()) + .config("hidden", 16) + .out("classify(nodes, gcn(A,h))") .autosize(new EmptyTensor(dataset.samples().getSlice().size())); ModelTraining trainer = new ModelTraining() - .setOptimizer(new Adam(modelBuilder.getConfigOrDefault("lr", 0.01))) - .setEpochs(modelBuilder.getConfigOrDefault("epochs", 1000)) - .setPatience(modelBuilder.getConfigOrDefault("patience", 100)) + .configFrom(modelBuilder) .setVerbose(true) .setLoss(new CategoricalCrossEntropy()) .setValidationLoss(new CategoricalCrossEntropy()); diff --git a/JGNN/src/main/java/mklab/JGNN/adhoc/ModelBuilder.java b/JGNN/src/main/java/mklab/JGNN/adhoc/ModelBuilder.java index 8bb9399..2d3017f 100644 --- a/JGNN/src/main/java/mklab/JGNN/adhoc/ModelBuilder.java +++ b/JGNN/src/main/java/mklab/JGNN/adhoc/ModelBuilder.java @@ -355,7 +355,7 @@ public ModelBuilder param(String name, double regularization, Tensor value) { * @see #param(String, double, Tensor) */ public ModelBuilder config(String name, double value) { - configurations.put(name, value); + this.configurations.put(name, value); return this; } @@ -908,7 +908,7 @@ else if(functions.containsKey(splt[2])) { this.configurations = new HashMap(this.configurations); HashMap customNames = new HashMap(); for(int i=0;i tokens = extractTokens(functions.get(splt[2])); HashSet keywords = new HashSet(); keywords.addAll(functions.keySet()); @@ -970,7 +970,6 @@ else if(!keywords.contains(token) && !isNumeric(token) && !prevHash) if(!prevHash && !prevTemp) newExpr += token; } - customNames.putAll(renameLater); this.operation(newExpr); this.configurations = configStack; return this; @@ -978,8 +977,16 @@ else if(!keywords.contains(token) && !isNumeric(token) && !prevHash) else throw new RuntimeException("Invalid operation: "+desc); - if(arg0.contains(":")) + if(arg0.contains(":")) { + String config = arg0.substring(0, arg0.indexOf(":")).trim(); + String value = arg0.substring(arg0.indexOf(":")+1).trim(); + if(value.equals("extern")) { + if(!this.configurations.containsKey(config)) + throw new RuntimeException("Required external config: "+config); + } + this.config(config, parseConfigValue(value)); return this; + } if(arg0!=null) { assertExists(arg0); diff --git a/JGNN/src/main/java/mklab/JGNN/adhoc/parsers/TextBuilder.java b/JGNN/src/main/java/mklab/JGNN/adhoc/parsers/Neuralang.java similarity index 64% rename from JGNN/src/main/java/mklab/JGNN/adhoc/parsers/TextBuilder.java rename to JGNN/src/main/java/mklab/JGNN/adhoc/parsers/Neuralang.java index ee2a2b1..ef7c4fa 100644 --- a/JGNN/src/main/java/mklab/JGNN/adhoc/parsers/TextBuilder.java +++ b/JGNN/src/main/java/mklab/JGNN/adhoc/parsers/Neuralang.java @@ -1,16 +1,30 @@ package mklab.JGNN.adhoc.parsers; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + import mklab.JGNN.adhoc.ModelBuilder; import mklab.JGNN.core.Tensor; -public class TextBuilder extends ModelBuilder { - public TextBuilder() { +public class Neuralang extends ModelBuilder { + public Neuralang() { } - public TextBuilder config(String name, double value) { + public Neuralang config(String name, double value) { super.config(name, value); return this; } - public TextBuilder parse(String text) { + public Neuralang parse(Path path) { + try { + parse(String.join("\n", Files.readAllLines(path))); + } catch (IOException e) { + e.printStackTrace(); + } + return this; + } + + public Neuralang parse(String text) { int depth = 0; String progress = ""; for(int i=0;i