Skip to content

Commit

Permalink
Some changes to xgboost stats computation
Browse files: browse the repository at this point in the history
  • Loading branch information
asprasad committed Jul 5, 2021
1 parent 047a08b commit 8e13700
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion model_stat_utils/compute_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def ConstructTreeEnsembleFromXGB(xgboostJSON):


# Args : Model filename, model format (XGBoost, LightGBM)
filename = os.path.join(modelFileDir, "abalone_xgb_model_save.json")
filename = os.path.join(modelFileDir, "year_prediction_msd_xgb_model_save.json")
modelJSON = ReadModelJSONFile(filename)
ensemble = ConstructTreeEnsembleFromXGB(modelJSON)
stats = ensemble.ComputeTreeSizeStatistics()
Expand Down
15 changes: 10 additions & 5 deletions model_stat_utils/decision_tree_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@ def AggregateListOfLists(lst):
for l in lst:
if len(l) > len(agg):
n = len(l) - len(agg)
agg = agg + [0] * n
for j in range(n):
agg.append([])
for i in range(len(l)):
agg[i] += l[i]
return agg
agg[i].append(l[i])

stats = []
for i in range(len(agg)):
stats.append(ComputeListStats(agg[i]))
return stats

class TreeNode:
def __init__(self, nodeType, threshold, featureIndex) -> None:
Expand Down Expand Up @@ -131,7 +136,7 @@ def NumberOfFeaturesUsedToLeaves(self):

def SortedAggregateFeaturesUsesOnPath(self):
featureUsesOnPaths = [leaf.FeatureUsesInPathToRoot() for leaf in self.leaves]
return AggregateListOfLists(featureUsesOnPaths)
return featureUsesOnPaths

class Feature:
def __init__(self, name, type, index) -> None:
Expand Down Expand Up @@ -187,7 +192,7 @@ def AggregateFeatureUses(self):
return aggregateFeatureUses, aggregateSortedUses

def AggregateSortedFeatureUsesOnPath(self):
featureUses = [t.SortedAggregateFeaturesUsesOnPath() for t in self.trees]
featureUses = [uses for t in self.trees for uses in t.SortedAggregateFeaturesUsesOnPath()]
return AggregateListOfLists(featureUses)

def ComputeTreeSizeStatistics(self):
Expand Down

0 comments on commit 8e13700

Please sign in to comment.