Skip to content

Commit

Permalink
optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
maniospas committed May 28, 2024
1 parent 0dddcc2 commit cc092e2
Show file tree
Hide file tree
Showing 69 changed files with 204 additions and 75 deletions.
Binary file modified JGNN.jar
Binary file not shown.
2 changes: 1 addition & 1 deletion JGNN/.classpath
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
</attributes>
</classpathentry>
<classpathentry kind="src" path="src/examples"/>
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/JavaSE-1.8">
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/JavaSE-17">
<attributes>
<attribute name="module" value="true"/>
<attribute name="maven.pomderived" value="true"/>
Expand Down
10 changes: 5 additions & 5 deletions JGNN/.settings/org.eclipse.jdt.core.prefs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
eclipse.preferences.version=1
org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled
org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8
org.eclipse.jdt.core.compiler.codegen.targetPlatform=17
org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve
org.eclipse.jdt.core.compiler.compliance=1.8
org.eclipse.jdt.core.compiler.compliance=17
org.eclipse.jdt.core.compiler.debug.lineNumber=generate
org.eclipse.jdt.core.compiler.debug.localVariable=generate
org.eclipse.jdt.core.compiler.debug.sourceFile=generate
org.eclipse.jdt.core.compiler.problem.assertIdentifier=error
org.eclipse.jdt.core.compiler.problem.enablePreviewFeatures=disabled
org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning
org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=ignore
org.eclipse.jdt.core.compiler.release=disabled
org.eclipse.jdt.core.compiler.source=1.8
org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=warning
org.eclipse.jdt.core.compiler.release=enabled
org.eclipse.jdt.core.compiler.source=17
7 changes: 7 additions & 0 deletions JGNN/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,12 @@
<artifactId>junit</artifactId>
<version>4.13.1</version>
</dependency>
<dependency>
<groupId>it.unimi.dsi</groupId>
<artifactId>fastutil-core</artifactId>
<version>8.5.11</version>
</dependency>


</dependencies>
</project>
6 changes: 3 additions & 3 deletions JGNN/src/examples/nodeClassification/GCN.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import mklab.JGNN.nn.ModelTraining;
import mklab.JGNN.core.Slice;
import mklab.JGNN.core.Tensor;
import mklab.JGNN.core.empy.EmptyMatrix;
import mklab.JGNN.core.empy.EmptyTensor;
import mklab.JGNN.nn.initializers.XavierNormal;
import mklab.JGNN.nn.loss.CategoricalCrossEntropy;
Expand All @@ -35,14 +34,15 @@ public static void main(String[] args) throws Exception {
.config("hidden", numClasses)
.function("gcnlayer", "(A,h){z=dropout(A, 0.5)@(h@matrix(?, hidden, reg))+vector(?);return z}")
.layer("h{l+1}=relu(gcnlayer(A, h{l}))")
.config("hidden", "classes")
.layer("h{l+1}=gcnlayer(A, h{l})")
.classify()
.autosize(new EmptyTensor(dataset.samples().getSlice().size()));

ModelTraining trainer = new ModelTraining()
.setOptimizer(new Adam(0.01))
.setEpochs(300)
.setPatience(300)
.setEpochs(10)
.setPatience(100)
.setVerbose(true)
.setLoss(new CategoricalCrossEntropy())
.setValidationLoss(new CategoricalCrossEntropy());
Expand Down
5 changes: 4 additions & 1 deletion JGNN/src/examples/nodeClassification/Scripting.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import mklab.JGNN.core.empy.EmptyTensor;
import mklab.JGNN.nn.initializers.XavierNormal;
import mklab.JGNN.nn.loss.CategoricalCrossEntropy;
import mklab.JGNN.nn.optimizers.Adam;

/**
* Demonstrates classification with an architecture defined through the scripting engine.
Expand All @@ -36,7 +37,9 @@ public static void main(String[] args) throws Exception {
.autosize(new EmptyTensor(dataset.samples().getSlice().size()));

ModelTraining trainer = new ModelTraining()
.configFrom(modelBuilder)
.setEpochs(300)
.setPatience(100)
.setOptimizer(new Adam(0.01))
.setVerbose(true)
.setLoss(new CategoricalCrossEntropy())
.setValidationLoss(new CategoricalCrossEntropy());
Expand Down
93 changes: 64 additions & 29 deletions JGNN/src/main/java/mklab/JGNN/adhoc/ModelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,15 @@ public ModelBuilder config(String name, double value) {
return this;
}


public ModelBuilder config(String name, String value) {
Double val = configurations.get(value);
if(val==null)
throw new RuntimeException("No configuration "+value+" found");
this.configurations.put(name, val);
return this;
}

public int getConfigOrDefault(String name, int defaultValue) {
return (int)(double)configurations.getOrDefault(name, (double) defaultValue);
}
Expand Down Expand Up @@ -489,7 +498,7 @@ public ModelBuilder function(String name, String value) {
}

private static List<String> extractTokens(String input) {
String tokenRegex = "\\b[a-zA-Z_][a-zA-Z0-9_]*\\b"+"|\\b\\w+\\b|\\(|\\)|\\=|\\+|\\;|\\!|\\:|\\#|\\-|\\.|\\*|\\@|\\/|\\[|\\]|\\,|\\?|\\||\\{|\\}";
String tokenRegex = "\\b[a-zA-Z_][a-zA-Z0-9_]*\\b"+"|(\\\".*\\\")"+"|\\b\\w+\\b|\\(|\\)|\\=|\\+|\\;|\\!|\\:|\\#|\\-|\\.|\\*|\\@|\\/|\\[|\\]|\\,|\\?|\\||\\{|\\}";
Pattern tokenPattern = Pattern.compile(tokenRegex);
Matcher tokenMatcher = tokenPattern.matcher(input);
List<String> tokens = new ArrayList<>();
Expand All @@ -499,7 +508,9 @@ private static List<String> extractTokens(String input) {
}
return tokens;
}

private static boolean isString(String str) {
return str.startsWith("\"") && str.endsWith("\"");
}

private static boolean isNumeric(String str) {
if (str == null || str.isEmpty()) {
Expand Down Expand Up @@ -586,7 +597,7 @@ else if(c==')') {
|| (newDesc.endsWith(" vector ") && isDouble(arg))
|| (newDesc.endsWith(" sort ") && isDouble(arg))
|| (newDesc.endsWith(" reshape ") && isDouble(arg))
|| arg.equals("col") || arg.equals("row"))
|| arg.equals("col") || arg.equals("row") || arg.contains(":"))
args += arg;
else {
String tmpName = "_tmp"+tmpVariableIdentifier;
Expand Down Expand Up @@ -701,12 +712,19 @@ else if(splt[2].equals("softmax")) {
boolean mode = false;
if(splt.length>4) {
String modeText = splt[4].trim();
if(modeText.equals("col"))
if(splt.length>5)
modeText = splt[4]+splt[5];
if(modeText.endsWith(";"))
modeText = modeText.substring(0, modeText.length()-1);
if(!modeText.split("\\:")[0].trim().equals("dim"))
throw new RuntimeException("Second argument "+modeText+" to softmax should be a dim config (dim: \"row\" or dim: \"col\")");
modeText = modeText.substring(modeText.indexOf(":")+1).trim();
if(modeText.equals("\"col\""))
mode = false;
else if(modeText.equals("row"))
else if(modeText.equals("\"row\""))
mode = true;
else
throw new RuntimeException("Invalid argument "+modeText+" to softmax");
throw new RuntimeException("Invalid dim argument "+modeText+" to softmax");
}
component = new SoftMax(mode);
arg0 = splt[3];
Expand All @@ -715,12 +733,15 @@ else if(splt[2].equals("sum")) {
boolean mode = false;
if(splt.length>4) {
String modeText = splt[4].trim();
if(modeText.equals("col"))
if(!modeText.split("\\:")[0].trim().equals("dim"))
throw new RuntimeException("Second argument "+modeText+" to sum should be a dim config (dim: \"row\" or dim: \"col\")");
modeText = modeText.substring(modeText.indexOf(":")+1).trim();
if(modeText.equals("\"col\""))
mode = false;
else if(modeText.equals("row"))
else if(modeText.equals("\"row\""))
mode = true;
else
throw new RuntimeException("Invalid argument "+modeText+" to softmax");
throw new RuntimeException("Invalid dim argument "+modeText+" to sum");
}
component = new Sum(mode);
arg0 = splt[3];
Expand All @@ -729,12 +750,15 @@ else if(splt[2].equals("mean")) {
boolean mode = false;
if(splt.length>4) {
String modeText = splt[4].trim();
if(modeText.equals("col"))
if(!modeText.split("\\:")[0].trim().equals("dim"))
throw new RuntimeException("Second argument "+modeText+" to mean should be a dim config (dim: \"row\" or dim: \"col\")");
modeText = modeText.substring(modeText.indexOf(":")+1).trim();
if(modeText.equals("\"col\""))
mode = false;
else if(modeText.equals("row"))
else if(modeText.equals("\"row\""))
mode = true;
else
throw new RuntimeException("Invalid argument "+modeText+" to softmax");
throw new RuntimeException("Invalid dim argument "+modeText+" to mean");
}
component = new Mean(mode);
arg0 = splt[3];
Expand All @@ -743,12 +767,15 @@ else if(splt[2].equals("max")) {
boolean mode = false;
if(splt.length>4) {
String modeText = splt[4].trim();
if(modeText.equals("col"))
if(!modeText.split("\\:")[0].trim().equals("dim"))
throw new RuntimeException("Second argument "+modeText+" to max should be a dim config (dim: \"row\" or dim: \"col\")");
modeText = modeText.substring(modeText.indexOf(":")+1).trim();
if(modeText.equals("\"col\""))
mode = false;
else if(modeText.equals("row"))
else if(modeText.equals("\"row\""))
mode = true;
else
throw new RuntimeException("Invalid argument "+modeText+" to softmax");
throw new RuntimeException("Invalid dim argument "+modeText+" to max");
}
component = new Max(mode);
arg0 = splt[3];
Expand Down Expand Up @@ -907,20 +934,22 @@ else if(functions.containsKey(splt[2])) {
HashMap<String, Double> configStack = this.configurations;
this.configurations = new HashMap<String, Double>(this.configurations);
HashMap<String, String> customNames = new HashMap<String, String>();
for(int i=0;i<args.length;i++)
if(!args[i].contains(":"))
for(int i=0;i<Math.max(args.length, splt.length-3);i++)
if(!args[i].contains(":") && !splt[i+3].contains(":"))
customNames.put(args[i].trim(), splt[i+3]);
else {
String config = args[i].substring(0, args[i].indexOf(":")).trim();
String value = args[i].substring(args[i].indexOf(":")+1).trim();
if(value.equals("extern")) {
if(i<args.length) {
String config = args[i].substring(0, args[i].indexOf(":")).trim();
String value = args[i].substring(args[i].indexOf(":")+1).trim();
if(value.equals("extern")) {
if(!this.configurations.containsKey(config))
throw new RuntimeException("Required external config: "+config);
}
if(!this.configurations.containsKey(config))
throw new RuntimeException("Required external config: "+config);
this.config(config, parseConfigValue(value));
}
if(!this.configurations.containsKey(config))
this.config(config, parseConfigValue(value));
/*// these are parsed in the attempt to create an intermediate variable for the argument
if(i<splt.length-3) {
// these are parsed in the attempt to create an intermediate variable for the argument
if(i<splt.length-3) {
String config = splt[i+3].substring(0, splt[i+3].indexOf(":")).trim();
String value = splt[i+3].substring(splt[i+3].indexOf(":")+1).trim();
if(value.equals("extern")) {
Expand All @@ -929,17 +958,17 @@ else if(functions.containsKey(splt[2])) {
}
else
this.config(config, parseConfigValue(value));
}*/
}
}
List<String> tokens = extractTokens(functions.get(splt[2]));
HashSet<String> keywords = new HashSet<String>();
keywords.addAll(functions.keySet());
keywords.addAll(Arrays.asList(".", "+", "-", "*", "/", "@", ",", "(", ")", ";", "=",
keywords.addAll(Arrays.asList(".", "+", "-", "*", "/", "@", ",", "(", ")", ";", "=", "\"",
"max", "min", "vector", "matrix", "vec", "mat", "[", "]", "{", "}", "|", "#", "!", ":",
"extern", "softmax",
"from", "to", "reduce", "transpose", "attention", "att", "dropout", "drop",
"repeat", "exp", "nexp", "L1", "sigmoid", "transpose", "monitor",
"log", "tanh", "prelu", "lrelu", "relu", "reshape", "mean", "col", "row"));
"log", "tanh", "prelu", "lrelu", "relu", "reshape", "mean"));
keywords.addAll(this.components.keySet());
keywords.addAll(this.configurations.keySet());
customNames.put("return", splt[0]+" = ");
Expand All @@ -957,9 +986,15 @@ else if(functions.containsKey(splt[2])) {
renameLater.put(token, "_"+splt[2]+functionRepetition+"_stack"+id+"_"+token);
token = "_"+splt[2]+functionRepetition+"_stack"+id+"_"+token;
}
else if(i<tokens.size()-1 && tokens.get(i+1).equals(":")) {
// token = token; // this is the case where we have configuration arguments
}
else if(i>0 && tokens.get(i-1).equals(":")) {
// token = token; // this is the case where we have configuration values
}
else if(customNames.containsKey(token))
token = customNames.get(token);
else if(!keywords.contains(token) && !isNumeric(token) && !prevHash)
else if(!keywords.contains(token) && !isNumeric(token) && !isString(token) && !prevHash)
token = "_"+splt[2]+functionRepetition+"_"+token;
prevHash = token.equals("#");
prevTemp = token.equals("!");
Expand Down
6 changes: 5 additions & 1 deletion JGNN/src/main/java/mklab/JGNN/adhoc/parsers/FastBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public FastBuilder layer(String expression) {
public FastBuilder classify() {
var("nodes");
layer("h{l+1}=h{l}[nodes]");
layer("h{l+1}=softmax(h{l}, row)");
layer("h{l+1}=softmax(h{l}, dim: \"row\")");
out("h"+layer);
return this;
}
Expand All @@ -113,6 +113,10 @@ public FastBuilder function(String name, String value) {
super.function(name, value);
return this;
}
public FastBuilder config(String name, String value) {
super.config(name, value);
return this;
}
public FastBuilder config(String name, double value) {
super.config(name, value);
return this;
Expand Down
6 changes: 4 additions & 2 deletions JGNN/src/main/java/mklab/JGNN/core/Matrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -304,18 +304,20 @@ public void run() {
}
else {
if(estimateNumNonZeroElements()/getRows()<with.estimateNumNonZeroElements()/with.getCols()) {
final long withCols = with.getCols();
for(Entry<Long, Long> element : getNonZeroEntries()) {
long row = element.getKey();
long col = element.getValue();
for(long col2=0;col2<with.getCols();col2++)
for(long col2=0;col2<withCols;++col2)
ret.put(row, col2, ret.get(row, col2) + get(row, col)*with.get(col, col2));
}
}
else {
final long rows = getRows();
for(Entry<Long, Long> element : with.getNonZeroEntries()) {
long row = element.getKey();
long col = element.getValue();
for(long row1=0;row1<getRows();row1++)
for(long row1=0;row1<rows;++row1)
ret.put(row1, col, ret.get(row1, col) + get(row1, row)*with.get(row, col));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import mklab.JGNN.core.Matrix;
import mklab.JGNN.core.Tensor;
import mklab.JGNN.core.util.FastEntry;

/**
* Defines a matrix whose columns are all a copy of a {@link Tensor}.
Expand Down Expand Up @@ -45,6 +46,7 @@ public Iterator<Long> iterator() {
protected class Repeat2DIterator implements Iterator<Entry<Long, Long>>, Iterable<Entry<Long, Long>> {
private Iterator<Long> iterator;
private long current;
private final FastEntry<Long,Long> ret = new FastEntry<Long, Long>();
public Repeat2DIterator() {
this.iterator = column.iterator();
current = 0;
Expand All @@ -60,7 +62,11 @@ public Entry<Long, Long> next() {
iterator = column.iterator();
}
long pos = iterator.next();
return new AbstractMap.SimpleEntry<Long,Long>(Long.valueOf(pos), Long.valueOf(current));

ret.setKey(pos);
ret.setValue(current);
return ret;
//return new AbstractMap.SimpleEntry<Long,Long>(Long.valueOf(pos), Long.valueOf(current));
}
@Override
public Iterator<Entry<Long, Long>> iterator() {
Expand Down
7 changes: 6 additions & 1 deletion JGNN/src/main/java/mklab/JGNN/core/matrix/Diagonal.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.util.Map.Entry;
import mklab.JGNN.core.Matrix;
import mklab.JGNN.core.Tensor;
import mklab.JGNN.core.util.FastEntry;

/**
* Implements a square matrix whose diagonal elements are determined by the correspond values of
Expand Down Expand Up @@ -37,6 +38,7 @@ public Iterator<Long> iterator() {

protected class Diagonal2DIterator implements Iterator<Entry<Long, Long>>, Iterable<Entry<Long, Long>> {
private Iterator<Long> iterator;
private final FastEntry<Long,Long> ret = new FastEntry<Long, Long>();
public Diagonal2DIterator(Iterator<Long> iterator) {
this.iterator = iterator;
}
Expand All @@ -47,7 +49,10 @@ public boolean hasNext() {
@Override
public Entry<Long, Long> next() {
long pos = iterator.next();
return new AbstractMap.SimpleEntry<Long,Long>(Long.valueOf(pos), Long.valueOf(pos));
ret.setKey(pos);
ret.setValue(pos);
return ret;
//return new AbstractMap.SimpleEntry<Long,Long>(Long.valueOf(pos), Long.valueOf(pos));
}
@Override
public Iterator<Entry<Long, Long>> iterator() {
Expand Down
Loading

0 comments on commit cc092e2

Please sign in to comment.