Skip to content

Commit

Permalink
docs 85% complete
Browse files Browse the repository at this point in the history
  • Loading branch information
maniospas committed Aug 22, 2024
1 parent ccb9b71 commit 8e50e0e
Show file tree
Hide file tree
Showing 6 changed files with 641 additions and 784 deletions.
17 changes: 6 additions & 11 deletions JGNN/src/examples/nodeClassification/Scripting.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import mklab.JGNN.core.empy.EmptyTensor;
import mklab.JGNN.nn.initializers.XavierNormal;
import mklab.JGNN.nn.loss.CategoricalCrossEntropy;
import mklab.JGNN.nn.loss.report.VerboseLoss;
import mklab.JGNN.nn.optimizers.Adam;

/**
Expand All @@ -31,18 +32,12 @@ fn classify(nodes, h, epochs: !3000, patience: !100, lr: !0.01) {
return softmax(h[nodes], dim: "row");
}
fn gcnlayer(A, h, hidden: 16, reg: 0.005) {
h = A@h@matrix(?, hidden, reg) + vector(hidden);
return h;
return A@h@matrix(?, hidden, reg) + vector(hidden);
}
fn gcn(A, h, classes: extern) {
h = gcnlayer(A, h);
h = dropout(relu(h), 0.5);
h = gcnlayer(A, h, hidden: classes);
return h;
}
fn ngcn(A, h, nodes) {
h = classify(nodes, gcn(A,h));
return h;
return gcnlayer(A, h, hidden: classes);
}
""";

Expand All @@ -55,14 +50,14 @@ fn ngcn(A, h, nodes) {
.var("nodes")
.config("classes", numClasses)
.config("hidden", numClasses+2)
.out("ngcn(A,h, nodes)")
.out("classify(nodes, gcn(A,h))")
.autosize(new EmptyTensor(numSamples));
System.out.println("Preferred learning rate: "+modelBuilder.getConfig("lr"));

ModelTraining trainer = new ModelTraining()
.configFrom(modelBuilder)
.setVerbose(true)
.setLoss(new CategoricalCrossEntropy())
.setValidationLoss(new CategoricalCrossEntropy());
.setValidationLoss(new VerboseLoss(new CategoricalCrossEntropy()));

long tic = System.currentTimeMillis();
Slice nodes = dataset.samples().getSlice().shuffle(100);
Expand Down
2 changes: 2 additions & 0 deletions JGNN/src/examples/tutorial/Quickstart.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public static void main(String[] args) throws Exception {
.classify()
.autosize(new EmptyTensor(numSamples));

System.out.println(modelBuilder.getConfig("lr"));

ModelTraining trainer = new ModelTraining()
.setOptimizer(new Adam(0.01))
.setEpochs(3000)
Expand Down
58 changes: 48 additions & 10 deletions JGNN/src/main/java/mklab/JGNN/adhoc/ModelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ public ModelBuilder(Model model) {
public Model getModel() {
return model;
}

/**
* Serializes the model builder instance into a Path, such as
* <code>Paths.get("example.jgnn")</code>.
* @param path A serialized path.
* @return This builder's instance.
*/
public ModelBuilder save(Path path) {
try(BufferedWriter writer = Files.newBufferedWriter(path)){
writer.write(this.getClass().getCanonicalName()+"\n");
Expand Down Expand Up @@ -140,6 +147,13 @@ public ModelBuilder save(Path path) {
return this;
}

/**
* Loads a ModelBuilder instance from the provided path, such as <code>Paths.get("example.jgnn")</code>.
* The instance may have been serialized with any class that extends the model builder.
*
* @param path The provided path.
* @return The loaded ModelBuilder instance.
*/
public static ModelBuilder load(Path path) {
ModelBuilder builder;
try(BufferedReader reader = Files.newBufferedReader(path)){
Expand Down Expand Up @@ -350,14 +364,16 @@ public ModelBuilder param(String name, double regularization, Tensor value) {
/**
* Declares a configuration hyperparameter, which can be used to declare
* matrix and vector parameters during {@link #operation(String)} expressions.
* For in-expression use of hyperparameters, delcare them with {@link #constant(String, double)}.
* For in-expression use of hyperparameters, declare them with {@link #constant(String, double)}.
* In Neuralang terms, this is implements the broadcasting operation.
* @param name The name of the configuration hyperparameter.
* @param value The value to be assigned to the hyperparameter.
* Typically, provide a long number.
* This may also be a long number.
* @return The builder's instance.
* @see #operation(String)
* @see #param(String, Tensor)
* @see #param(String, double, Tensor)
* @see #config(String, String)
*/
public ModelBuilder config(String name, double value) {
if(name.equals("?"))
Expand All @@ -366,19 +382,41 @@ public ModelBuilder config(String name, double value) {
return this;
}


/**
* Applies {@link #config(String, double)} where the set value
* is obtained from another configuration hyperaparameter.
* @param name The name of the configuration hyperparameter to set.
* @param value The name of the configuration hyperparameter whose value should be copied.
* @return The builder's instance.
* @see #config(String, double)
*/
public ModelBuilder config(String name, String value) {
Double val = configurations.get(value);
return config(name, getConfig(value));
}

/**
* Retrieves a configuration hyperparameter's value.
* @param name The configuration's name.
* @return The retrieved value;
* @throws RuntimeException If a no configuration with the given name was found.
* @see #getConfigOrDefault(String, double)
*/
public double getConfig(String name) {
Double val = configurations.get(name);
if(val==null)
throw new RuntimeException("No configuration "+value+" found");
throw new RuntimeException("No configuration "+name+" found");
this.configurations.put(name, val);
return this;
}

public int getConfigOrDefault(String name, int defaultValue) {
return (int)(double)configurations.getOrDefault(name, (double) defaultValue);
return val;
}

/**
* Retrieves a configuration hyperparameter's value. If no such configuration
* exists, a default value is returned instead.
* @param name The configuration's name.
* @param defaultValue The default to be retrieved if no such configuration was found.
* @return The retrieved value;
* @see #getConfig(String)
*/
public double getConfigOrDefault(String name, double defaultValue) {
return configurations.getOrDefault(name, defaultValue);
}
Expand Down
25 changes: 23 additions & 2 deletions JGNN/src/main/java/mklab/JGNN/adhoc/parsers/Neuralang.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,33 @@
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

import mklab.JGNN.adhoc.ModelBuilder;
import mklab.JGNN.core.Tensor;

/**
* Extends the base {@link ModelBuilder} with the full capabilities of the Neuralang
* scripting language.
*
* @author Emmanouil Krasanakis
* @see #parse(String)
* @see #parse(Path)
*/
public class Neuralang extends ModelBuilder {
public Neuralang() {
}
public Neuralang config(String name, double value) {
super.config(name, value);
return this;
}
/**
* Parses a Neuralang source code file.
* Reads a file like <code>Paths.get("models.nn")</code>
* from disk with {@link Files#readAllLines(Path)}, and parses
* the loaded String.
* @param path The source code file.
* @return The Neuralang builder's instance.
* @see #parse(String)
*/
public Neuralang parse(Path path) {
try {
parse(String.join("\n", Files.readAllLines(path)));
Expand All @@ -24,6 +39,12 @@ public Neuralang parse(Path path) {
return this;
}

/**
* Parses Neuralang source code by handling function declarations in addition to
* other expressions.
* @param text The source code to parse.
* @return The Neuralang builder's instance.
*/
public Neuralang parse(String text) {
int depth = 0;
String progress = "";
Expand Down
6 changes: 3 additions & 3 deletions JGNN/src/main/java/mklab/JGNN/nn/ModelTraining.java
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,9 @@ public void run() {
}
public ModelTraining configFrom(ModelBuilder modelBuilder) {
setOptimizer(new Adam(modelBuilder.getConfigOrDefault("lr", 0.01)));
setEpochs(modelBuilder.getConfigOrDefault("epochs", epochs));
numBatches = modelBuilder.getConfigOrDefault("batches", numBatches);
setPatience(modelBuilder.getConfigOrDefault("patience", patience));
setEpochs((int)modelBuilder.getConfigOrDefault("epochs", epochs));
numBatches = (int)modelBuilder.getConfigOrDefault("batches", numBatches);
setPatience((int)modelBuilder.getConfigOrDefault("patience", patience));
return this;
}
}
Loading

0 comments on commit 8e50e0e

Please sign in to comment.