Skip to content

Commit

Permalink
Refactored integration testing data format selection
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Dec 10, 2023
1 parent cdadd4e commit 570b56e
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Predicate;

import com.google.common.base.Equivalence;
Expand All @@ -40,8 +41,15 @@
abstract
public class XGBoostEncoderBatch extends ModelEncoderBatch {

private String format = System.getProperty(XGBoostEncoderBatch.class.getName() + ".format", "model,json,ubj");
private String[] formats = {XGBoostFormats.BINARY, XGBoostFormats.JSON, XGBoostFormats.UBJSON};

{
String format = System.getProperty(XGBoostEncoderBatch.class.getName() + ".format", null);

if(format != null){
this.formats = format.split(",");
}
}

public XGBoostEncoderBatch(String algorithm, String dataset, Predicate<ResultField> columnFilter, Equivalence<Object> equivalence){
super(algorithm, dataset, columnFilter, equivalence);
Expand Down Expand Up @@ -83,7 +91,7 @@ public String getFeatureMapPath(){
public PMML getPMML() throws Exception {
PMML result = null;

String[] formats = this.format.split(",");
String[] formats = getFormats();
for(String format : formats){
PMML pmml = loadPMML(getLearnerPath(format), getFeatureMapPath());

Expand All @@ -107,6 +115,14 @@ public String getOutputCsvPath(){
return super.getOutputCsvPath();
}

public String[] getFormats(){
return this.formats;
}

public void setFormats(String[] formats){
this.formats = Objects.requireNonNull(formats);
}

protected PMML loadPMML(String learnerPath, String featureMapPath) throws Exception {
Learner learner;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) 2023 Villu Ruusmann
*
* This file is part of JPMML-XGBoost
*
* JPMML-XGBoost is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-XGBoost is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-XGBoost. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.xgboost.testing;

public interface XGBoostFormats {

String BINARY = "model";
String JSON = "json";
String UBJSON = "ubj";
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.jpmml.evaluator.testing.FloatEquivalence;
import org.junit.Test;

public class ClassificationTest extends XGBoostEncoderBatchTest implements XGBoostAlgorithms, XGBoostDatasets, Fields {
public class ClassificationTest extends XGBoostEncoderBatchTest implements XGBoostAlgorithms, XGBoostDatasets, XGBoostFormats, Fields {

public ClassificationTest(){
super(new FloatEquivalence(4));
Expand All @@ -39,6 +39,15 @@ public ClassificationTest(){
public XGBoostEncoderBatch createBatch(String algorithm, String dataset, Predicate<ResultField> columnFilter, Equivalence<Object> equivalence){
XGBoostEncoderBatch result = new XGBoostEncoderBatch(algorithm, dataset, columnFilter, equivalence){

{
String dataset = getDataset();

// XXX
if(dataset.startsWith(AUDIT)){
setFormats(new String[]{JSON, UBJSON});
}
}

@Override
public ClassificationTest getArchiveBatchTest(){
return ClassificationTest.this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,41 @@
*/
package org.jpmml.xgboost.testing;

import java.util.function.Predicate;

import com.google.common.base.Equivalence;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.testing.FloatEquivalence;
import org.junit.Test;

public class RegressionTest extends XGBoostEncoderBatchTest implements XGBoostAlgorithms, XGBoostDatasets {
public class RegressionTest extends XGBoostEncoderBatchTest implements XGBoostAlgorithms, XGBoostDatasets, XGBoostFormats {

public RegressionTest(){
super(new FloatEquivalence(4));
}

@Override
public XGBoostEncoderBatch createBatch(String algorithm, String dataset, Predicate<ResultField> columnFilter, Equivalence<Object> equivalence){
XGBoostEncoderBatch result = new XGBoostEncoderBatch(algorithm, dataset, columnFilter, equivalence){

{
String dataset = getDataset();

// XXX
if(dataset.startsWith(AUDIT) || dataset.startsWith(AUTO) || dataset.startsWith(VISIT)){
setFormats(new String[]{JSON, UBJSON});
}
}

@Override
public RegressionTest getArchiveBatchTest(){
return RegressionTest.this;
}
};

return result;
}

@Test
public void evaluateLinearAuto() throws Exception {
evaluate(LINEAR_REGRESSION, AUTO, new FloatEquivalence(8));
Expand Down

0 comments on commit 570b56e

Please sign in to comment.