diff --git a/JGNN/neural_graph.jggn b/JGNN/neural_graph.jggn deleted file mode 100644 index 540bc2c..0000000 --- a/JGNN/neural_graph.jggn +++ /dev/null @@ -1,49 +0,0 @@ -mklab.JGNN.adhoc.parsers.LayeredBuilder -features = config: 2.0 -hidden = config: 8.0 -reg = config: 0.005 -classes = config: 2.0 -reduced = config: 6.0 -2hidden = config: 16.0 -hiddenReduced = config: 48.0 -A = var: null -h0 = var: null -_tmp11 = param DenseMatrix (2hidden 16, hidden 8): [-0.19328476877863923,-0.6816842687001173,0.33084330660685063,0.29012383747022974,-0.008856823759693731,0.3243209298674903,0.03630835867502023,-0.36893022633353045,0.3852823246760141,0.1675140168304929,0.34749646804722034,-0.043510359020801116,-0.5480261005662305,-0.36506518491627604,0.4618722893431745,-0.31670662910836095,-0.06766959139063297,3.964965286227935E-4,0.42952042800161827,-0.6955929308960008,-0.6286510760021554,-0.0965623432938279,-0.18299829835981102,-0.05956609007597157,-0.2920302998132937,-0.2547563636041248,0.15096554826682856,0.3775004405962659,-0.8995334446683099,0.06333771747725044,0.65460552575879,-0.3186670866531967,-0.5309119132179063,0.394784224886659,0.29509546796445657,-0.4690491776981738,-0.10572405664194967,0.27624178602476357,0.019287597474623407,-0.08221736520461895,0.3752314288377563,-0.14388855480059862,0.5031819562803397,-0.39231949747503947,0.030226728241452908,0.539230918192968,-0.1866028170582539,-0.054814542702866335,-0.037001880635509805,-0.05866626266145604,0.6570208369054049,0.4037472351461765,-0.7434327464304322,-0.1800452245270526,0.40850049372980474,0.22084528373229415,0.01669375687488634,-0.01769463510961822,0.48369486751642693,-0.4513516083571352,-0.18643307652802016,0.47531322785792024,-0.15418322605338194,0.08869303689826054,0.6496841665481397,0.12057361296404132,-0.10087202695777087,0.20077395712411783,0.12298342734788265,0.4463359109741554,0.595268214487788,0.3198579540060096,0.15003772422143247,0.3036177944924227,0.39176025176862367,0.4385176746516399,0.29267786236453575,-0.538075342791701,0.18692588008664582,-0.22874264781324213,-0.2956689694666771,0.08507352063026218,0.16405172192265624,-0.16627517419517734,0.7367038620405035,0.23692065568506798,-0.39340620243769026,-0.1218118529198394,0.34270271396790364,0.10611754538241865,0.23511719104140957,0.06257205623691434,-0.026719449424170108,0.13262598343728552,0.3321072492887023,-0.322213838336707,-0.09058229994288335,-0.03864904977375033,0.18830357531855324,0.035130801296273106,-0.40433428753295697,0.34087491252629254,0.8426791513172762,0.07823534114592788,0.017292176905372218,-0.07825012073294113,-0.551117338968398,0.5990794942026233,-0.3498034727233782,-0.15620124985299463,-0.1313686261579107,-0.17445929115062758,0.05423176395512194,0.3305972373060727,-0.07433345988400326,0.02211377023555965,0.5169316243904939,-0.35626175262079174,0.2605377618616539,0.040877884464697034,0.6011922447538958,0.10590649114094869,0.04866151867585637,1.0498653811003313,-0.4831934782954236,0.2155992503280004,-0.4449982638493458,-0.15205818353654899] -_tmp12 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] -_tmp17 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] -_tmp16 = param DenseMatrix (2hidden 16, hidden 8): [0.5021034187023461,0.13868567394933753,-0.4319463012253357,-0.01957776716822693,-0.6094145749292609,0.1439419292273619,-0.10279569620728086,-0.24391336506490913,-0.07060120291020784,0.365303228749883,0.46701090217314395,-0.07945034851001198,0.18256005736919872,0.523490149473091,0.3783930271864008,-0.5348058052156716,0.248710923254586,-0.18060630411319933,-0.9064084134026116,-0.5047861771258582,-0.21041070643734053,-0.2508198800475379,-0.19029786209348837,-0.34694159663184004,0.02890813945657403,0.11926584949928104,0.008029066336753142,-0.30160347629852446,0.4142657749089172,-0.10276450006376554,0.21668328634092998,0.7794994419864524,0.31912230822442206,-0.7896441157875099,0.0669323910449409,0.005079087575026865,0.3871876711708779,-0.027795613673683106,0.5195156321578849,0.1451244649129374,-0.1234044285044108,-0.3910978872449653,0.19433235371428767,-0.484786863765534,-0.06500045233351805,-0.13535973107007787,-0.001400944670307145,0.14315247913721696,-0.44489526404547725,-0.7113623734549364,-0.3537344028937658,-0.2826613161956493,0.38693511145728415,-0.5540629852227372,-0.26564026626370785,-0.7528758353657103,-0.5030481946916305,0.7361802029086209,0.048380119415574255,0.19193865286784526,-0.18572427292244453,-0.1070739962101043,-0.4667238425266632,0.1724481017787257,0.07907853007683265,0.07794419563775512,-0.2258137337359118,0.7440959006326245,-0.684988546540687,0.17920646820148134,0.46240699117307765,0.4320721338608193,-0.37346377979218376,-0.4607102416897734,-0.5562339107003924,-1.1335825147833678,-0.6780782600923088,0.3572746140653733,-0.41292103314568673,-0.31905109650720387,-0.25370387856853094,-0.480925487249198,-0.3151921264430039,0.07776796917965208,-0.1788986631998994,0.08475066278957165,-0.5237351973530295,0.4193834593785406,0.03681254346412109,-0.584427852645633,-0.02896556963429372,-0.1299053244734533,-0.2218704270185149,0.24571193054996324,0.42048425749103463,-0.8880484715017992,0.31700958719137495,-0.4581902441160126,-0.30034222408803624,-0.7713140669541653,0.1906079322091254,0.5614853908408053,-0.12311847026365601,0.094618305065478,0.16756886278939304,0.5190453004025772,0.1461756848410729,0.10340340085403356,0.009543912257454283,-0.4945736675559478,0.18117986153395202,0.4416916735062306,0.2303877229303233,0.22360404940516052,-0.6112281541522248,-0.3049268148884489,-0.21320808341784314,0.23185142098280917,0.4502900379269799,-0.165791378027787,0.7864019252894919,0.2813364624331278,0.9474109407516412,0.12430889772546948,-0.2172150595120388,0.4876508580337231,0.344365758524931,0.36338942310374867] -_tmp20 = param DenseMatrix (hidden 8, hidden 8): [0.6814404680132916,0.25361885765312314,0.10256163018756725,0.1922851650348111,-0.05410077006150943,-0.38549110279625004,0.5004202987313823,-0.2324453468875155,-0.22099505283361742,0.06551162436227614,0.5271888950214714,0.22805732984591637,0.42595558449507714,-0.2595469927668719,0.41457441383114,-0.40693167679055886,0.39803040041408844,0.3440959433142813,-0.5753750505438064,0.06237296287559791,-0.4211115226435094,-0.1763906243200833,1.1985422471982727,-0.07676527261740763,0.2326178151531441,0.35704961284054426,-0.26729485975443884,0.4095283664505398,0.18168483029293744,-0.2932381742502261,0.8414330706272171,0.37977048763315235,0.5062335030466366,0.9869523173389951,1.6838631648078104,0.1263031934267131,-0.5478652955503843,-3.035840047741802E-4,-0.3174070977338061,0.7191699495163361,-0.07943020358021734,0.21984277828430565,0.8033588782047216,-0.1861673333583821,0.5596593401061914,-0.35031592936519707,-0.37717583187856096,0.23602515813879557,0.0449456664499197,-0.47412881293932263,0.21589515054272587,-0.20568191860829724,0.2901169807397064,0.4150953875668683,0.03210280126721354,-0.7381724464055942,0.08744218760930367,1.657249410635567,0.38200285347323787,0.41650869362747445,0.12939056231065268,0.40376690602764725,-0.047337732760709926,0.6966554415998744] -_tmp21 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] -_tmp3 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] -_tmp2 = param DenseMatrix (features 2, hidden 8): [0.0962596204088395,-0.028884691451251224,-0.028116488038815862,0.3577563699462411,-0.1915495424716903,-0.5879573411284648,0.16687000433532928,0.18406857056345563,-0.5892489716766497,0.9247104404685634,-0.06459238390322171,-1.1421568977936178,-0.722566667961998,-1.3796258912305088,-0.06280695024528939,1.7517702074586412] -_tmp23 = param DenseMatrix (hiddenReduced 48, classes 2): [0.40938433617236525,-0.4987199441662107,0.26203992378101376,-0.43133379215734263,0.6709105283816391,-0.9814661346740309,0.539095205750881,-0.4207697065158933,0.0240229334618446,-0.14223548752189802,0.4438842753898452,-0.49120390072985726,0.17395173314865078,0.06471401007521176,0.4478543276829584,0.43432759563978957,0.3447420155314258,-1.1253578954761723,0.4993227096710841,0.30042014773726416,0.6795526379343243,-0.5195903531075249,0.4621000372201286,-0.7689859292393859,0.9106708726562932,-0.6086682072036363,-0.49443734786718163,-2.0015702142041314,0.2759331578478036,-0.9644772192516243,-0.14250917516719866,-5.04944040825525,1.8578946969831234,-2.689873422142453,17.819588902507228,2.3524952524944034,0.607415470001217,-0.005342721841655263,0.2345666718277672,-0.6435856594935998,0.09465516271962526,-1.028924483552736,0.5494581031684685,0.004150240845003633,0.09245869079258418,-1.159928537758163,0.5395232430817755,-0.8876905496583974,-0.43672870130761726,-0.6001793288795612,-0.2209758164529676,0.3771985686838059,-0.5264231473231888,1.2086668405130667,-0.3677004028561642,-0.23961520350564466,-0.1312633323701847,0.20028985803474514,-0.5997756663497473,1.1861726582546483,-0.5171240051903958,-0.36876941825191734,-0.08069918314199832,-0.6008323309697636,0.2923273123036548,1.4251798900771944,-0.3047689840101401,-0.11336172298411515,-0.5004652657880019,-0.25323032177253096,-0.04182412745906058,0.7914127746669496,-0.7185649111348198,0.4124581236471469,0.8006830216814292,1.4727565356504557,-0.98152228243357,0.6496030752864841,0.2086967713555983,5.221403415918512,-2.005888671728391,2.2603888220693245,-17.575851034929514,-2.500026725809498,-0.24434030852837782,-0.33232526683552893,0.3851370120960596,1.4858460025078009,-0.421678210018014,0.7016615436437713,-0.578971814276199,0.029161238557184666,-0.06956312718637947,1.0065065640325745,-0.5451395936226792,1.0762157054167099] -_tmp5 = param DenseMatrix (hidden 8, hidden 8): [-0.30954935594367816,0.12116378595232227,-0.306386786291067,-0.5220864721581343,-0.003000295537359033,-0.37628978694357235,-0.2695812643245417,0.6700734618145359,0.20645751601752704,-0.05260439493400403,0.17754381848591488,-0.17735111718384539,0.167779299276608,0.27389750680988467,0.23764308545771937,1.6708105311783883,0.06728854514002967,-0.44839216787805547,0.13632607512869982,0.37582277585252005,-0.28216295775529515,-0.5678759903551267,-0.35640357469197825,0.3560222333683681,-0.06393612839774328,-0.43344284383245607,-0.3986854818366697,0.26732202379329006,0.34858445673548166,0.01968688359473008,0.7123181505558769,-0.6411211857602117,-0.271134106235963,0.2491391197457063,0.0939090724621519,-0.7099186482668389,0.4388952099959059,0.8554823355249708,-1.166165039369413,1.1988479095710083,-0.2320232541561487,-0.08166507448918221,-0.41733625408668373,-0.6475038987362935,-0.4785141909120537,-0.6513421097528718,0.7196243595671512,0.0627929089576402,0.08109030285230427,-2.3186007809859356E-4,0.06453881213468693,0.7580057257515331,-0.12080130803301743,-0.17046938190024785,0.28216868068309153,0.005635712488808225,-0.45095495159758453,0.13584676111699664,-0.73411011182003,-0.22715114851512125,-0.5224650791751124,0.8293987840564809,0.021505674769915967,-0.2903500417602574] -_tmp6 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] -edgeSrc = from A -edgeDst = to A -_tmp1 = h0 @ _tmp2 -_tmp0 = _tmp1 + _tmp3 -h1 = relu _tmp0 -_tmp4 = h1 @ _tmp5 -h2 = _tmp4 + _tmp6 -_tmp7 = h2 [ edgeSrc ] -_tmp8 = h2 [ edgeDst ] -message2 = _tmp7 | _tmp8 -_tmp10 = message2 @ _tmp11 -_tmp9 = _tmp10 + _tmp12 -transformed2 = relu _tmp9 -received2 = reduce ( transformed2 , A ) -_tmp14 = received2 | h2 -_tmp15 = _tmp14 @ _tmp16 -_tmp13 = _tmp15 + _tmp17 -i2 = relu _tmp13 -_tmp19 = i2 @ _tmp20 -_tmp18 = _tmp19 + _tmp21 -h3 = relu _tmp18 -z3 = sort ( h3 , reduced ) -_tmp22 = h3 [ z3 ] -h4 = reshape ( _tmp22 , 1 , hiddenReduced ) -h5 = h4 @ _tmp23 -h6 = softmax ( h5 , row ) - -return h6 diff --git a/JGNN/src/examples/graphClassification/GCNlang.java b/JGNN/src/examples/graphClassification/GCNlang.java deleted file mode 100644 index 51383b1..0000000 --- a/JGNN/src/examples/graphClassification/GCNlang.java +++ /dev/null @@ -1,108 +0,0 @@ -package graphClassification; - -import java.util.Arrays; - -import mklab.JGNN.adhoc.Dataset; -import mklab.JGNN.adhoc.ModelBuilder; -import mklab.JGNN.adhoc.datasets.Cora; -import mklab.JGNN.adhoc.parsers.FastBuilder; -import mklab.JGNN.adhoc.parsers.LayeredBuilder; -import mklab.JGNN.core.Matrix; -import mklab.JGNN.nn.Loss; -import mklab.JGNN.nn.Model; -import mklab.JGNN.nn.ModelTraining; -import mklab.JGNN.core.Slice; -import mklab.JGNN.core.Tensor; -import mklab.JGNN.core.ThreadPool; -import mklab.JGNN.nn.initializers.XavierNormal; -import mklab.JGNN.nn.loss.CategoricalCrossEntropy; -import mklab.JGNN.nn.optimizers.Adam; -import mklab.JGNN.nn.optimizers.BatchOptimizer; - -/** - * Demonstrates classification with the GCN architecture. - * - * @author Emmanouil Krasanakis - */ -public class GCNlang { - - - public static void main(String[] args){ - long reduced = 5; // input graphs need to have at least that many nodes, lower values decrease accuracy - long hidden = 8; // since this library does not use GPU parallelization, many latent dims reduce speed - - - Dataset dataset = new Cora(); - - ModelBuilder modelBuilder = new LangBuilder() - .config("features", dataset.features().getCols()) - .config("classes", 0.005) - .parse("fn gcnlayer(A, h, hidden=64, reg=0.005) {" - + "return relu(A@h@matrix(?, hidden, reg) + vector(hidden));" - + "}" - + "fn gcn(A, h) {" - + " h = gcnlayer(A, h);" - + " h = gcnlayer(A, h, hidden=classes);" - + " return h;\r\n" - + "}") - .assertBackwardValidity(); - - - /*ModelBuilder builder = new LayeredBuilder() - .var("A") - .config("features", 1) - .config("classes", 2) - .config("reduced", reduced) - .config("hidden", hidden) - .config("reg", 0.005) - .layer("h{l+1}=relu(A@(h{l}@matrix(features, hidden, reg))+vector(hidden))") - .layer("h{l+1}=relu(A@(h{l}@matrix(hidden, hidden, reg))+vector(hidden))") - .concat(2) // concatenates the outputs of the last 2 layers - .config("hiddenReduced", hidden*2*reduced) // 2* due to concatenation - .operation("z{l}=sort(h{l}, reduced)") // currently, the parser fails to understand full expressions within next step's gather, so we need to create this intermediate variable - .layer("h{l+1}=reshape(h{l}[z{l}], 1, hiddenReduced)") // - .layer("h{l+1}=h{l}@matrix(hiddenReduced, classes)") - .layer("h{l+1}=softmax(h{l}, row)") - //.layer("h{l+1}=softmax(sum(h{l}@matrix(hiddenReduced, classes), row))")//this is mean pooling to replace the above sort pooling - .out("h{l}"); */ - - TrajectoryData dtrain = new TrajectoryData(8000); - TrajectoryData dtest = new TrajectoryData(2000); - - Model model = builder.getModel().init(new XavierNormal()); - BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.01)); - Loss loss = new CategoricalCrossEntropy(); - for(int epoch=0; epoch<600; epoch++) { - // gradient update over all graphs - for(int graphId=0; graphId temp = new HashMap(); + HashMap renameLater = new HashMap(); + for(int i=0;i