From 570b56ef167bbf350a7a34a7126170eb334575cf Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Sun, 10 Dec 2023 11:50:48 +0200 Subject: [PATCH] Refactored integration testing data format selection --- .../xgboost/testing/XGBoostEncoderBatch.java | 20 +++++++++++-- .../jpmml/xgboost/testing/XGBoostFormats.java | 26 +++++++++++++++++ .../xgboost/testing/ClassificationTest.java | 11 +++++++- .../jpmml/xgboost/testing/RegressionTest.java | 28 ++++++++++++++++++- 4 files changed, 81 insertions(+), 4 deletions(-) create mode 100644 pmml-xgboost/src/main/java/org/jpmml/xgboost/testing/XGBoostFormats.java diff --git a/pmml-xgboost/src/main/java/org/jpmml/xgboost/testing/XGBoostEncoderBatch.java b/pmml-xgboost/src/main/java/org/jpmml/xgboost/testing/XGBoostEncoderBatch.java index 3b9d40d..1e7df49 100644 --- a/pmml-xgboost/src/main/java/org/jpmml/xgboost/testing/XGBoostEncoderBatch.java +++ b/pmml-xgboost/src/main/java/org/jpmml/xgboost/testing/XGBoostEncoderBatch.java @@ -22,6 +22,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.function.Predicate; import com.google.common.base.Equivalence; @@ -40,8 +41,15 @@ abstract public class XGBoostEncoderBatch extends ModelEncoderBatch { - private String format = System.getProperty(XGBoostEncoderBatch.class.getName() + ".format", "model,json,ubj"); + private String[] formats = {XGBoostFormats.BINARY, XGBoostFormats.JSON, XGBoostFormats.UBJSON}; + { + String format = System.getProperty(XGBoostEncoderBatch.class.getName() + ".format", null); + + if(format != null){ + this.formats = format.split(","); + } + } public XGBoostEncoderBatch(String algorithm, String dataset, Predicate columnFilter, Equivalence equivalence){ super(algorithm, dataset, columnFilter, equivalence); @@ -83,7 +91,7 @@ public String getFeatureMapPath(){ public PMML getPMML() throws Exception { PMML result = null; - String[] formats = this.format.split(","); + String[] formats = getFormats(); for(String format : formats){ PMML pmml = loadPMML(getLearnerPath(format), getFeatureMapPath()); @@ -107,6 +115,14 @@ public String getOutputCsvPath(){ return super.getOutputCsvPath(); } + public String[] getFormats(){ + return this.formats; + } + + public void setFormats(String[] formats){ + this.formats = Objects.requireNonNull(formats); + } + protected PMML loadPMML(String learnerPath, String featureMapPath) throws Exception { Learner learner; diff --git a/pmml-xgboost/src/main/java/org/jpmml/xgboost/testing/XGBoostFormats.java b/pmml-xgboost/src/main/java/org/jpmml/xgboost/testing/XGBoostFormats.java new file mode 100644 index 0000000..f785f19 --- /dev/null +++ b/pmml-xgboost/src/main/java/org/jpmml/xgboost/testing/XGBoostFormats.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2023 Villu Ruusmann + * + * This file is part of JPMML-XGBoost + * + * JPMML-XGBoost is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * JPMML-XGBoost is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with JPMML-XGBoost. If not, see . + */ +package org.jpmml.xgboost.testing; + +public interface XGBoostFormats { + + String BINARY = "model"; + String JSON = "json"; + String UBJSON = "ubj"; +} \ No newline at end of file diff --git a/pmml-xgboost/src/test/java/org/jpmml/xgboost/testing/ClassificationTest.java b/pmml-xgboost/src/test/java/org/jpmml/xgboost/testing/ClassificationTest.java index f2fe13f..154ca54 100644 --- a/pmml-xgboost/src/test/java/org/jpmml/xgboost/testing/ClassificationTest.java +++ b/pmml-xgboost/src/test/java/org/jpmml/xgboost/testing/ClassificationTest.java @@ -29,7 +29,7 @@ import org.jpmml.evaluator.testing.FloatEquivalence; import org.junit.Test; -public class ClassificationTest extends XGBoostEncoderBatchTest implements XGBoostAlgorithms, XGBoostDatasets, Fields { +public class ClassificationTest extends XGBoostEncoderBatchTest implements XGBoostAlgorithms, XGBoostDatasets, XGBoostFormats, Fields { public ClassificationTest(){ super(new FloatEquivalence(4)); @@ -39,6 +39,15 @@ public ClassificationTest(){ public XGBoostEncoderBatch createBatch(String algorithm, String dataset, Predicate columnFilter, Equivalence equivalence){ XGBoostEncoderBatch result = new XGBoostEncoderBatch(algorithm, dataset, columnFilter, equivalence){ + { + String dataset = getDataset(); + + // XXX + if(dataset.startsWith(AUDIT)){ + setFormats(new String[]{JSON, UBJSON}); + } + } + @Override public ClassificationTest getArchiveBatchTest(){ return ClassificationTest.this; diff --git a/pmml-xgboost/src/test/java/org/jpmml/xgboost/testing/RegressionTest.java b/pmml-xgboost/src/test/java/org/jpmml/xgboost/testing/RegressionTest.java index 31e7c63..b36aa7d 100644 --- a/pmml-xgboost/src/test/java/org/jpmml/xgboost/testing/RegressionTest.java +++ b/pmml-xgboost/src/test/java/org/jpmml/xgboost/testing/RegressionTest.java @@ -18,15 +18,41 @@ */ package org.jpmml.xgboost.testing; +import java.util.function.Predicate; + +import com.google.common.base.Equivalence; +import org.jpmml.evaluator.ResultField; import org.jpmml.evaluator.testing.FloatEquivalence; import org.junit.Test; -public class RegressionTest extends XGBoostEncoderBatchTest implements XGBoostAlgorithms, XGBoostDatasets { +public class RegressionTest extends XGBoostEncoderBatchTest implements XGBoostAlgorithms, XGBoostDatasets, XGBoostFormats { public RegressionTest(){ super(new FloatEquivalence(4)); } + @Override + public XGBoostEncoderBatch createBatch(String algorithm, String dataset, Predicate columnFilter, Equivalence equivalence){ + XGBoostEncoderBatch result = new XGBoostEncoderBatch(algorithm, dataset, columnFilter, equivalence){ + + { + String dataset = getDataset(); + + // XXX + if(dataset.startsWith(AUDIT) || dataset.startsWith(AUTO) || dataset.startsWith(VISIT)){ + setFormats(new String[]{JSON, UBJSON}); + } + } + + @Override + public RegressionTest getArchiveBatchTest(){ + return RegressionTest.this; + } + }; + + return result; + } + @Test public void evaluateLinearAuto() throws Exception { evaluate(LINEAR_REGRESSION, AUTO, new FloatEquivalence(8));